1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/xla_builder.h"
17 
18 #include <functional>
19 #include <numeric>
20 #include <queue>
21 #include <string>
22 #include <utility>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/memory/memory.h"
28 #include "absl/strings/match.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_join.h"
31 #include "tensorflow/compiler/xla/client/sharding_builder.h"
32 #include "tensorflow/compiler/xla/client/xla_computation.h"
33 #include "tensorflow/compiler/xla/execution_options_util.h"
34 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
35 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
36 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
37 #include "tensorflow/compiler/xla/service/shape_inference.h"
38 #include "tensorflow/compiler/xla/util.h"
39 
40 namespace xla {
41 
42 using absl::StrCat;
43 
44 namespace {
45 
// Character placed between a base name and its numeric id in generated names.
constexpr char kNameSeparator = '.';
47 
48 // Retrieves the base name of an instruction or computation fully qualified
49 // name, using separator as boundary between the initial base name part, and
50 // the numeric identification.
GetBaseName(const string & name,char separator)51 string GetBaseName(const string& name, char separator) {
52   auto pos = name.rfind(separator);
53   CHECK_NE(pos, string::npos) << name;
54   return name.substr(0, pos);
55 }
56 
57 // Generates a fully qualified computation/instruction name.
GetFullName(const string & base_name,char separator,int64 id)58 string GetFullName(const string& base_name, char separator, int64 id) {
59   const char separator_str[] = {separator, '\0'};
60   return StrCat(base_name, separator_str, id);
61 }
62 
63 // Common function to standardize setting name and IDs on computation and
64 // instruction proto entities.
65 template <typename T>
SetProtoIdAndName(T * entry,const string & base_name,char separator,int64 id)66 void SetProtoIdAndName(T* entry, const string& base_name, char separator,
67                        int64 id) {
68   entry->set_id(id);
69   entry->set_name(GetFullName(base_name, separator, id));
70 }
71 
72 }  // namespace
73 
// Arithmetic operator overloads for XlaOp. Each one forwards to the free
// function of the corresponding operation.
XlaOp operator-(const XlaOp& x) {
  return Neg(x);
}
XlaOp operator+(const XlaOp& x, const XlaOp& y) {
  return Add(x, y);
}
XlaOp operator-(const XlaOp& x, const XlaOp& y) {
  return Sub(x, y);
}
XlaOp operator*(const XlaOp& x, const XlaOp& y) {
  return Mul(x, y);
}
XlaOp operator/(const XlaOp& x, const XlaOp& y) {
  return Div(x, y);
}
XlaOp operator%(const XlaOp& x, const XlaOp& y) {
  return Rem(x, y);
}
80 
// Bitwise/logical and left-shift operator overloads for XlaOp. Each one
// forwards to the free function of the corresponding operation.
XlaOp operator~(const XlaOp& x) {
  return Not(x);
}
XlaOp operator&(const XlaOp& x, const XlaOp& y) {
  return And(x, y);
}
XlaOp operator|(const XlaOp& x, const XlaOp& y) {
  return Or(x, y);
}
XlaOp operator^(const XlaOp& x, const XlaOp& y) {
  return Xor(x, y);
}
XlaOp operator<<(const XlaOp& x, const XlaOp& y) {
  return ShiftLeft(x, y);
}
86 
// Right-shift overload. The shape of `x` decides the flavour of the shift:
// signed element types use ShiftRightArithmetic, all other integral types use
// ShiftRightLogical. A non-integral `x` yields an InvalidArgument error that
// is routed through ReportErrorOrReturn on x's builder.
XlaOp operator>>(const XlaOp& x, const XlaOp& y) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(xla::Shape lhs_shape, builder->GetShape(x));
    if (!ShapeUtil::ElementIsIntegral(lhs_shape)) {
      return InvalidArgument(
          "Argument to >> operator does not have an integral type (%s).",
          ShapeUtil::HumanString(lhs_shape));
    }
    return ShapeUtil::ElementIsSigned(lhs_shape) ? ShiftRightArithmetic(x, y)
                                                 : ShiftRightLogical(x, y);
  });
}
103 
GetShape(const XlaOp & op) const104 StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
105   TF_RETURN_IF_ERROR(first_error_);
106 
107   TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op));
108   return Shape(instr->shape());
109 }
110 
GetOperandShapes(absl::Span<const XlaOp> operands) const111 StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
112     absl::Span<const XlaOp> operands) const {
113   std::vector<Shape> operand_shapes;
114   for (const XlaOp& operand : operands) {
115     TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
116     operand_shapes.push_back(shape);
117   }
118   return operand_shapes;
119 }
120 
XlaBuilder(const string & computation_name)121 XlaBuilder::XlaBuilder(const string& computation_name)
122     : name_(computation_name) {}
123 
~XlaBuilder()124 XlaBuilder::~XlaBuilder() {}
125 
ReportError(const Status & error)126 XlaOp XlaBuilder::ReportError(const Status& error) {
127   CHECK(!error.ok());
128   if (die_immediately_on_error_) {
129     LOG(FATAL) << "error building computation: " << error;
130   }
131 
132   if (first_error_.ok()) {
133     first_error_ = error;
134     first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
135   }
136   return XlaOp(this);
137 }
138 
ReportErrorOrReturn(const StatusOr<XlaOp> & op)139 XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) {
140   if (!first_error_.ok()) {
141     return XlaOp(this);
142   }
143   if (!op.ok()) {
144     return ReportError(op.status());
145   }
146   return op.ValueOrDie();
147 }
148 
ReportErrorOrReturn(const std::function<StatusOr<XlaOp> ()> & op_creator)149 XlaOp XlaBuilder::ReportErrorOrReturn(
150     const std::function<StatusOr<XlaOp>()>& op_creator) {
151   return ReportErrorOrReturn(op_creator());
152 }
153 
GetProgramShape(int64 root_id) const154 StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
155   TF_RETURN_IF_ERROR(first_error_);
156   TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
157                       LookUpInstructionByHandle(root_id));
158 
159   ProgramShape program_shape;
160 
161   *program_shape.mutable_result() = Shape(root_proto->shape());
162 
163   // Check that the parameter numbers are continuous from 0, and add parameter
164   // shapes and names to the program shape.
165   const int64 param_count = parameter_numbers_.size();
166   for (int64 i = 0; i < param_count; i++) {
167     program_shape.add_parameters();
168     program_shape.add_parameter_names();
169   }
170   for (const HloInstructionProto& instr : instructions_) {
171     // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
172     // to verify continuity, we just need to verify that every parameter is in
173     // the right range.
174     if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
175       const int64 index = instr.parameter_number();
176       TF_RET_CHECK(index >= 0 && index < param_count)
177           << "invalid parameter number: " << index;
178       *program_shape.mutable_parameters(index) = Shape(instr.shape());
179       *program_shape.mutable_parameter_names(index) = instr.name();
180     }
181   }
182   return program_shape;
183 }
184 
GetProgramShape() const185 StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
186   TF_RET_CHECK(!instructions_.empty());
187   return GetProgramShape(instructions_.back().id());
188 }
189 
GetProgramShape(XlaOp root) const190 StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
191   if (root.builder_ != this) {
192     return InvalidArgument("Given root operation is not in this computation.");
193   }
194   return GetProgramShape(root.handle());
195 }
196 
IsConstantVisitor(const int64 op_handle,absl::flat_hash_set<int64> * visited,bool * is_constant) const197 void XlaBuilder::IsConstantVisitor(const int64 op_handle,
198                                    absl::flat_hash_set<int64>* visited,
199                                    bool* is_constant) const {
200   if (visited->contains(op_handle) || !*is_constant) {
201     return;
202   }
203 
204   const HloInstructionProto& instr =
205       *(LookUpInstructionByHandle(op_handle).ValueOrDie());
206   const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
207   switch (opcode) {
208     default:
209       for (const int64 operand_id : instr.operand_ids()) {
210         IsConstantVisitor(operand_id, visited, is_constant);
211       }
212       // TODO(b/32495713): We aren't checking the called computations.
213       break;
214     case HloOpcode::kGetDimensionSize: {
215       int64 dimension_number = instr.dimensions(0);
216       const HloInstructionProto& operand =
217           *(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie());
218       Shape operand_shape(operand.shape());
219       if (operand_shape.is_dynamic_dimension(dimension_number)) {
220         *is_constant = false;
221       }
222       break;
223     }
224 
225     // Non functional ops.
226     case HloOpcode::kRng:
227     case HloOpcode::kAllReduce:
228       // TODO(b/33009255): Implement constant folding for cross replica sum.
229     case HloOpcode::kInfeed:
230     case HloOpcode::kOutfeed:
231     case HloOpcode::kCall:
232       // TODO(b/32495713): We aren't checking the to_apply computation itself,
233       // so we conservatively say that computations containing the Call op
234       // cannot be constant.  We cannot set is_functional=false in other similar
235       // cases since we're already relying on IsConstant to return true.
236     case HloOpcode::kCustomCall:
237     case HloOpcode::kWhile:
238       // TODO(b/32495713): We aren't checking the condition and body
239       // computations themselves.
240     case HloOpcode::kScatter:
241       // TODO(b/32495713): We aren't checking the embedded computation in
242       // Scatter.
243     case HloOpcode::kSend:
244     case HloOpcode::kRecv:
245     case HloOpcode::kParameter:
246       *is_constant = false;
247       break;
248   }
249   if (!*is_constant) {
250     VLOG(1) << "Non-constant: " << instr.name();
251   }
252   visited->insert(op_handle);
253 }
254 
SetDynamicBinding(int64 dynamic_size_param_num,ShapeIndex dynamic_size_param_index,int64 target_param_num,ShapeIndex target_param_index,int64 target_dim_num)255 Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
256                                      ShapeIndex dynamic_size_param_index,
257                                      int64 target_param_num,
258                                      ShapeIndex target_param_index,
259                                      int64 target_dim_num) {
260   bool param_exists = false;
261   for (HloInstructionProto& instr : instructions_) {
262     if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
263         instr.parameter_number() == target_param_num) {
264       param_exists = true;
265       Shape param_shape(instr.shape());
266       Shape* param_shape_ptr = &param_shape;
267       for (int64 index : target_param_index) {
268         param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index);
269       }
270       // TODO(b/121223198): Set `is_dynamic` to the parameter shape when XLA
271       // backend can handle dynamic dimensions.
272       *instr.mutable_shape() = param_shape.ToProto();
273     }
274   }
275 
276   if (!param_exists) {
277     return InvalidArgument(
278         "Asked to mark parameter %lld as dynamic sized parameter, but the "
279         "doesn't exists",
280         target_param_num);
281   }
282 
283   TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind(
284       DynamicParameterBinding::DynamicParameter{dynamic_size_param_num,
285                                                 dynamic_size_param_index},
286       DynamicParameterBinding::DynamicDimension{
287           target_param_num, target_param_index, target_dim_num}));
288   return Status::OK();
289 }
290 
BuildAndNoteError()291 XlaComputation XlaBuilder::BuildAndNoteError() {
292   DCHECK(parent_builder_ != nullptr);
293   auto build_status = Build();
294   if (!build_status.ok()) {
295     parent_builder_->ReportError(
296         AddStatus(build_status.status(), absl::StrCat("error from: ", name_)));
297     return {};
298   }
299   return build_status.ConsumeValueOrDie();
300 }
301 
GetCurrentStatus() const302 Status XlaBuilder::GetCurrentStatus() const {
303   if (!first_error_.ok()) {
304     string backtrace;
305     first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
306     return AppendStatus(first_error_, backtrace);
307   }
308   return Status::OK();
309 }
310 
Build(bool remove_dynamic_dimensions)311 StatusOr<XlaComputation> XlaBuilder::Build(bool remove_dynamic_dimensions) {
312   TF_RETURN_IF_ERROR(GetCurrentStatus());
313   return Build(instructions_.back().id(), remove_dynamic_dimensions);
314 }
315 
Build(XlaOp root,bool remove_dynamic_dimensions)316 StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root,
317                                            bool remove_dynamic_dimensions) {
318   if (root.builder_ != this) {
319     return InvalidArgument("Given root operation is not in this computation.");
320   }
321   return Build(root.handle(), remove_dynamic_dimensions);
322 }
323 
Build(int64 root_id,bool remove_dynamic_dimensions)324 StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id,
325                                            bool remove_dynamic_dimensions) {
326   TF_RETURN_IF_ERROR(GetCurrentStatus());
327 
328   // TODO(b/121223198): XLA backend cannot handle dynamic dimensions yet, remove
329   // all dynamic dimensions before building xla program until we have support in
330   // the backend.
331   if (remove_dynamic_dimensions) {
332     std::function<void(ShapeProto*)> remove_dynamic_dimension =
333         [&](ShapeProto* shape) {
334           if (shape->tuple_shapes_size() != 0) {
335             for (int64 i = 0; i < shape->tuple_shapes_size(); ++i) {
336               remove_dynamic_dimension(shape->mutable_tuple_shapes(i));
337             }
338           }
339           for (int64 i = 0; i < shape->dimensions_size(); ++i) {
340             shape->set_is_dynamic_dimension(i, false);
341           }
342         };
343 
344     for (auto& instruction : instructions_) {
345       remove_dynamic_dimension(instruction.mutable_shape());
346     }
347   }
348 
349   HloComputationProto entry;
350   SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
351   TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
352   *entry.mutable_program_shape() = program_shape.ToProto();
353   entry.set_root_id(root_id);
354 
355   for (auto& instruction : instructions_) {
356     // Ensures that the instruction names are unique among the whole graph.
357     instruction.set_name(
358         GetFullName(instruction.name(), kNameSeparator, instruction.id()));
359     entry.add_instructions()->Swap(&instruction);
360   }
361 
362   XlaComputation computation(entry.id());
363   HloModuleProto* module = computation.mutable_proto();
364   module->set_name(entry.name());
365   module->set_id(entry.id());
366   module->set_entry_computation_name(entry.name());
367   module->set_entry_computation_id(entry.id());
368   *module->mutable_host_program_shape() = entry.program_shape();
369   for (auto& e : embedded_) {
370     module->add_computations()->Swap(&e.second);
371   }
372   module->add_computations()->Swap(&entry);
373   if (!input_output_aliases_.empty()) {
374     TF_RETURN_IF_ERROR(
375         PopulateInputOutputAlias(module, program_shape, input_output_aliases_));
376   }
377   *(module->mutable_dynamic_parameter_binding()) =
378       dynamic_parameter_binding_.ToProto();
379 
380   // Clear data held by this builder.
381   this->instructions_.clear();
382   this->handle_to_index_.clear();
383   this->embedded_.clear();
384   this->parameter_numbers_.clear();
385 
386   return std::move(computation);
387 }
388 
PopulateInputOutputAlias(HloModuleProto * module,const ProgramShape & program_shape,const std::vector<InputOutputAlias> & input_output_aliases)389 /* static */ Status XlaBuilder::PopulateInputOutputAlias(
390     HloModuleProto* module, const ProgramShape& program_shape,
391     const std::vector<InputOutputAlias>& input_output_aliases) {
392   HloInputOutputAliasConfig config(program_shape.result());
393   for (auto& alias : input_output_aliases) {
394     // The HloInputOutputAliasConfig does not do parameter validation as it only
395     // carries the result shape. Maybe it should be constructed with a
396     // ProgramShape to allow full validation. We will still get an error when
397     // trying to compile the HLO module, but would be better to have validation
398     // at this stage.
399     if (alias.param_number >= program_shape.parameters_size()) {
400       return InvalidArgument("Invalid parameter number %ld (total %ld)",
401                              alias.param_number,
402                              program_shape.parameters_size());
403     }
404     const Shape& parameter_shape = program_shape.parameters(alias.param_number);
405     if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) {
406       return InvalidArgument("Invalid parameter %ld index: %s",
407                              alias.param_number,
408                              alias.param_index.ToString().c_str());
409     }
410     TF_RETURN_IF_ERROR(config.SetUpAlias(
411         alias.output_index, alias.param_number, alias.param_index,
412         HloInputOutputAliasConfig::AliasKind::kUserAlias));
413   }
414   *module->mutable_input_output_alias() = config.ToProto();
415   return Status::OK();
416 }
417 
InDimBroadcast(const Shape & shape,const XlaOp & operand,absl::Span<const int64> broadcast_dimensions)418 StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
419     const Shape& shape, const XlaOp& operand,
420     absl::Span<const int64> broadcast_dimensions) {
421   TF_RETURN_IF_ERROR(first_error_);
422 
423   HloInstructionProto instr;
424   *instr.mutable_shape() = shape.ToProto();
425   for (int64 dim : broadcast_dimensions) {
426     instr.add_dimensions(dim);
427   }
428   return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
429 }
430 
AddBroadcastSequence(const Shape & output_shape,const XlaOp & operand)431 StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
432                                                  const XlaOp& operand) {
433   TF_RETURN_IF_ERROR(first_error_);
434 
435   TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
436 
437   CHECK(ShapeUtil::IsScalar(operand_shape) ||
438         operand_shape.rank() == output_shape.rank());
439   Shape broadcast_shape =
440       ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type());
441 
442   // Do explicit broadcast for scalar.
443   if (ShapeUtil::IsScalar(operand_shape)) {
444     return InDimBroadcast(broadcast_shape, operand, {});
445   }
446 
447   // Do explicit broadcast for degenerate broadcast.
448   std::vector<int64> broadcast_dimensions;
449   std::vector<int64> reshaped_dimensions;
450   for (int i = 0; i < operand_shape.rank(); i++) {
451     if (operand_shape.dimensions(i) == output_shape.dimensions(i)) {
452       broadcast_dimensions.push_back(i);
453       reshaped_dimensions.push_back(operand_shape.dimensions(i));
454     } else {
455       TF_RET_CHECK(operand_shape.dimensions(i) == 1)
456           << "An explicit broadcast sequence requires the broadcasted "
457              "dimensions to be trivial; operand shape: "
458           << operand_shape << "; output_shape: " << output_shape;
459     }
460   }
461   // Eliminate the size one dimensions.
462   TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand,
463                       Reshape(ShapeUtil::MakeShape(operand_shape.element_type(),
464                                                    reshaped_dimensions),
465                               operand));
466   // Broadcast 'reshape' up to the larger size.
467   return InDimBroadcast(broadcast_shape, reshaped_operand,
468                         broadcast_dimensions);
469 }
470 
UnaryOp(HloOpcode unop,const XlaOp & operand)471 XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
472   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
473     HloInstructionProto instr;
474     TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
475     TF_ASSIGN_OR_RETURN(Shape shape,
476                         ShapeInference::InferUnaryOpShape(unop, operand_shape));
477     *instr.mutable_shape() = shape.ToProto();
478     return AddInstruction(std::move(instr), unop, {operand});
479   });
480 }
481 
BinaryOp(HloOpcode binop,const XlaOp & lhs,const XlaOp & rhs,absl::Span<const int64> broadcast_dimensions,absl::optional<ComparisonDirection> direction)482 XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
483                            absl::Span<const int64> broadcast_dimensions,
484                            absl::optional<ComparisonDirection> direction) {
485   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
486     HloInstructionProto instr;
487     TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
488     TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
489     TF_ASSIGN_OR_RETURN(Shape shape,
490                         ShapeInference::InferBinaryOpShape(
491                             binop, lhs_shape, rhs_shape, broadcast_dimensions));
492     *instr.mutable_shape() = shape.ToProto();
493     if (binop == HloOpcode::kCompare) {
494       if (!direction.has_value()) {
495         return InvalidArgument(
496             "kCompare expects a ComparisonDirection, but none provided.");
497       }
498       instr.set_comparison_direction(ComparisonDirectionToString(*direction));
499     } else if (direction.has_value()) {
500       return InvalidArgument(
501           "A comparison direction is provided for a non-compare opcode: %s.",
502           HloOpcodeString(binop));
503     }
504 
505     const int64 lhs_rank = lhs_shape.rank();
506     const int64 rhs_rank = rhs_shape.rank();
507 
508     XlaOp updated_lhs = lhs;
509     XlaOp updated_rhs = rhs;
510 
511     if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) {
512       const bool should_broadcast_lhs = lhs_rank < rhs_rank;
513       XlaOp from = should_broadcast_lhs ? lhs : rhs;
514       const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape;
515 
516       std::vector<int64> to_size;
517       std::vector<bool> to_size_is_dynamic;
518       for (int i = 0; i < shape.rank(); i++) {
519         to_size.push_back(shape.dimensions(i));
520         to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i));
521       }
522       for (int64 from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
523         int64 to_dim = broadcast_dimensions[from_dim];
524         to_size[to_dim] = from_shape.dimensions(from_dim);
525         to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim);
526       }
527 
528       const Shape& broadcasted_shape = ShapeUtil::MakeShape(
529           from_shape.element_type(), to_size, to_size_is_dynamic);
530       TF_ASSIGN_OR_RETURN(
531           XlaOp broadcasted_operand,
532           InDimBroadcast(broadcasted_shape, from, broadcast_dimensions));
533 
534       updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs;
535       updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
536     }
537 
538     TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs));
539     if (!ShapeUtil::SameDimensions(shape, updated_lhs_shape)) {
540       TF_ASSIGN_OR_RETURN(updated_lhs,
541                           AddBroadcastSequence(shape, updated_lhs));
542     }
543     TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs));
544     if (!ShapeUtil::SameDimensions(shape, updated_rhs_shape)) {
545       TF_ASSIGN_OR_RETURN(updated_rhs,
546                           AddBroadcastSequence(shape, updated_rhs));
547     }
548 
549     return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs});
550   });
551 }
552 
TernaryOp(HloOpcode triop,const XlaOp & lhs,const XlaOp & rhs,const XlaOp & ehs)553 XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
554                             const XlaOp& ehs) {
555   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
556     HloInstructionProto instr;
557     TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
558     TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
559     TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs));
560     TF_ASSIGN_OR_RETURN(
561         Shape shape, ShapeInference::InferTernaryOpShape(triop, lhs_shape,
562                                                          rhs_shape, ehs_shape));
563     *instr.mutable_shape() = shape.ToProto();
564     XlaOp updated_lhs = lhs;
565     XlaOp updated_rhs = rhs;
566     XlaOp updated_ehs = ehs;
567     if (!shape.IsTuple()) {
568       if (!lhs_shape.IsTuple() &&
569           !ShapeUtil::SameDimensions(shape, lhs_shape)) {
570         // lhs is being implicitly broadcasted. Change to explicit.
571         TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(shape, lhs));
572       }
573       if (!rhs_shape.IsTuple() &&
574           !ShapeUtil::SameDimensions(shape, rhs_shape)) {
575         // rhs is being implicitly broadcasted. Change to explicit.
576         TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(shape, rhs));
577       }
578       if (!ehs_shape.IsTuple() &&
579           !ShapeUtil::SameDimensions(shape, ehs_shape)) {
580         // ehs is being implicitly broadcasted. Change to explicit.
581         TF_ASSIGN_OR_RETURN(updated_ehs, AddBroadcastSequence(shape, ehs));
582       }
583     }
584     return AddInstruction(std::move(instr), triop,
585                           {updated_lhs, updated_rhs, updated_ehs});
586   });
587 }
588 
ConstantLiteral(const LiteralSlice & literal)589 XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
590   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
591     HloInstructionProto instr;
592     *instr.mutable_shape() = literal.shape().ToProto();
593     *instr.mutable_literal() = literal.ToProto();
594     return AddInstruction(std::move(instr), HloOpcode::kConstant);
595   });
596 }
597 
Iota(const Shape & shape,int64 iota_dimension)598 XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) {
599   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
600     HloInstructionProto instr;
601     *instr.mutable_shape() = shape.ToProto();
602     instr.add_dimensions(iota_dimension);
603     return AddInstruction(std::move(instr), HloOpcode::kIota);
604   });
605 }
606 
Iota(PrimitiveType type,int64 size)607 XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) {
608   return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0);
609 }
610 
Call(const XlaComputation & computation,absl::Span<const XlaOp> operands)611 XlaOp XlaBuilder::Call(const XlaComputation& computation,
612                        absl::Span<const XlaOp> operands) {
613   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
614     HloInstructionProto instr;
615     std::vector<const Shape*> operand_shape_ptrs;
616     TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
617     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
618                       [](const Shape& shape) { return &shape; });
619     TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
620                         computation.GetProgramShape());
621     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
622                                          operand_shape_ptrs,
623                                          /*to_apply=*/called_program_shape));
624     *instr.mutable_shape() = shape.ToProto();
625 
626     AddCalledComputation(computation, &instr);
627 
628     return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
629   });
630 }
631 
Parameter(int64 parameter_number,const Shape & shape,const string & name)632 XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
633                             const string& name) {
634   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
635     HloInstructionProto instr;
636     if (!parameter_numbers_.insert(parameter_number).second) {
637       return InvalidArgument("parameter %d already registered",
638                              parameter_number);
639     }
640     instr.set_parameter_number(parameter_number);
641     instr.set_name(name);
642     *instr.mutable_shape() = shape.ToProto();
643     return AddInstruction(std::move(instr), HloOpcode::kParameter);
644   });
645 }
646 
// Client-level broadcast: the result shape is computed by
// ShapeInference::InferBroadcastShape from the operand shape and
// `broadcast_sizes`, and the operand's dimensions are mapped onto the
// highest-numbered dimensions of that result.
XlaOp XlaBuilder::Broadcast(const XlaOp& operand,
                            absl::Span<const int64> broadcast_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape& shape,
        ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes));

    // The client-level broadcast op just appends dimensions on the left (adds
    // lowest numbered dimensions). The HLO broadcast instruction is more
    // flexible and can add new dimensions anywhere. Its dimensions field maps
    // operand dimensions to output dimensions, so to append on the left the
    // operand's n dimensions must map to the n highest dimension numbers of
    // the output: operand dimension i goes to (output rank - n) + i.
    const int64 operand_rank = operand_shape.rank();
    std::vector<int64> dimensions(operand_rank);
    std::iota(dimensions.begin(), dimensions.end(),
              shape.rank() - operand_rank);
    return InDimBroadcast(shape, operand, dimensions);
  });
}
670 
// Broadcasts `operand` to an array with dimension sizes `out_dim_size`, where
// `broadcast_dimensions[i]` names the output dimension that operand dimension
// i maps to.
//
// Every entry of `broadcast_dimensions` must be a valid output dimension
// number, i.e. lie in [0, out_dim_size.size()); otherwise InvalidArgument is
// returned. The remaining validation is delegated to
// ShapeInference::InferBroadcastShape. Degenerate broadcasts (an operand
// dimension whose size differs from the output's) are handled with an
// additional AddBroadcastSequence step.
XlaOp XlaBuilder::BroadcastInDim(
    const XlaOp& operand, const absl::Span<const int64> out_dim_size,
    const absl::Span<const int64> broadcast_dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    // Output shape, in the case of degenerate broadcast, the out_dim_size is
    // not necessarily the same as the dimension sizes of the output shape.
    auto output_shape =
        ShapeUtil::MakeShape(operand_shape.element_type(), out_dim_size);
    const int64 output_rank = out_dim_size.size();
    for (int i = 0; i < broadcast_dimensions.size(); i++) {
      // Valid output dimension numbers are 0 .. output_rank - 1; the value is
      // used to index `output_shape` right below.
      if (broadcast_dimensions[i] < 0 ||
          broadcast_dimensions[i] >= output_rank) {
        return InvalidArgument("Broadcast dimension %lld is out of bound",
                               broadcast_dimensions[i]);
      }
      output_shape.set_dynamic_dimension(broadcast_dimensions[i],
                                         operand_shape.is_dynamic_dimension(i));
    }

    TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape(
                           operand_shape, output_shape, broadcast_dimensions)
                           .status());

    // Shape of the first broadcast: the output sizes, except that mapped
    // dimensions keep the operand's own size.
    std::vector<int64> in_dim_size(out_dim_size.begin(), out_dim_size.end());
    for (int i = 0; i < broadcast_dimensions.size(); i++) {
      in_dim_size[broadcast_dimensions[i]] = operand_shape.dimensions(i);
    }
    const auto& in_dim_shape =
        ShapeUtil::MakeShape(operand_shape.element_type(), in_dim_size);
    TF_ASSIGN_OR_RETURN(
        XlaOp in_dim_broadcast,
        InDimBroadcast(in_dim_shape, operand, broadcast_dimensions));

    // If broadcast is not degenerate, return broadcasted result.
    if (ShapeUtil::Equal(in_dim_shape, output_shape)) {
      return in_dim_broadcast;
    }

    // Otherwise handle degenerate broadcast case.
    return AddBroadcastSequence(output_shape, in_dim_broadcast);
  });
}
712 
Reshape(const Shape & shape,const XlaOp & operand)713 StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) {
714   TF_RETURN_IF_ERROR(first_error_);
715 
716   HloInstructionProto instr;
717   *instr.mutable_shape() = shape.ToProto();
718   return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
719 }
720 
// Adds a kSlice instruction over `operand`. Entry `i` of `start_indices`,
// `limit_indices` and `strides` becomes the start, limit and stride of the
// i-th slice dimension; the result shape comes from shape inference.
XlaOp XlaBuilder::Slice(const XlaOp& operand,
                        absl::Span<const int64> start_indices,
                        absl::Span<const int64> limit_indices,
                        absl::Span<const int64> strides) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferSliceShape(
                            operand_shape, start_indices, limit_indices,
                            strides));

    HloInstructionProto slice_instr;
    *slice_instr.mutable_shape() = result_shape.ToProto();

    // One slice_dimensions entry per entry of start_indices.
    const int64 num_slice_dims = start_indices.size();
    for (int64 dim = 0; dim < num_slice_dims; ++dim) {
      auto* dim_config = slice_instr.add_slice_dimensions();
      dim_config->set_start(start_indices[dim]);
      dim_config->set_limit(limit_indices[dim]);
      dim_config->set_stride(strides[dim]);
    }

    return AddInstruction(std::move(slice_instr), HloOpcode::kSlice,
                          {operand});
  });
}
742 
// Slices `operand` along the single dimension `dimno`, using `start_index`,
// `limit_index` and `stride` for that dimension. Every other dimension is
// taken in full (start 0, limit = dimension size, stride 1).
//
// Reports InvalidArgument if `dimno` is not a valid dimension of `operand`.
XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
                             int64 limit_index, int64 stride, int64 dimno) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
    // `dimno` indexes the per-dimension vectors below, so it must be checked
    // first; otherwise an invalid value writes outside the vectors.
    if (dimno < 0 || dimno >= shape.rank()) {
      return InvalidArgument(
          "SliceInDim dimension %d is out of bounds for operand shape %s",
          dimno, ShapeUtil::HumanString(shape));
    }
    std::vector<int64> starts(shape.rank(), 0);
    std::vector<int64> limits(shape.dimensions().begin(),
                              shape.dimensions().end());
    std::vector<int64> strides(shape.rank(), 1);
    starts[dimno] = start_index;
    limits[dimno] = limit_index;
    strides[dimno] = stride;
    return Slice(operand, starts, limits, strides);
  });
}
757 
// Adds a kDynamicSlice instruction whose operands are `operand` and the single
// `start_indices` op. `slice_sizes` is recorded on the instruction and used
// for shape inference.
XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                               absl::Span<const int64> slice_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
                        GetShape(start_indices));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferDynamicSliceShape(
                            operand_shape, {start_indices_shape}, slice_sizes));

    HloInstructionProto slice_instr;
    *slice_instr.mutable_shape() = result_shape.ToProto();
    for (const int64 slice_size : slice_sizes) {
      slice_instr.add_dynamic_slice_sizes(slice_size);
    }

    return AddInstruction(std::move(slice_instr), HloOpcode::kDynamicSlice,
                          {operand, start_indices});
  });
}
779 
// Adds a kDynamicSlice instruction whose operands are `operand` followed by
// every op in `start_indices`. `slice_sizes` is recorded on the instruction
// and used for shape inference.
XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand,
                               absl::Span<const XlaOp> start_indices,
                               absl::Span<const int64> slice_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    // Shape inference consumes the shapes directly, so no separate vector of
    // shape pointers is needed.
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferDynamicSliceShape(
                            operand_shape, start_indices_shapes, slice_sizes));
    *instr.mutable_shape() = shape.ToProto();

    for (int64 size : slice_sizes) {
      instr.add_dynamic_slice_sizes(size);
    }

    // Instruction operands: the sliced array first, then the start indices.
    std::vector<XlaOp> operands;
    operands.reserve(1 + start_indices.size());
    operands.push_back(operand);
    operands.insert(operands.end(), start_indices.begin(), start_indices.end());
    return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, operands);
  });
}
807 
// Adds a kDynamicUpdateSlice instruction with operands `operand`, `update`
// and the single `start_indices` op. The result shape comes from shape
// inference over the three operand shapes.
XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                                     const XlaOp& start_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update));
    TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
                        GetShape(start_indices));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferDynamicUpdateSliceShape(
                            operand_shape, update_shape,
                            {start_indices_shape}));

    HloInstructionProto update_instr;
    *update_instr.mutable_shape() = result_shape.ToProto();

    return AddInstruction(std::move(update_instr),
                          HloOpcode::kDynamicUpdateSlice,
                          {operand, update, start_indices});
  });
}
826 
// Adds a kDynamicUpdateSlice instruction whose operands are `operand`,
// `update`, and then every op in `start_indices`. The result shape comes from
// shape inference over those operand shapes.
XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                                     absl::Span<const XlaOp> start_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update));
    // Shape inference consumes the shapes directly, so no separate vector of
    // shape pointers is needed.
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferDynamicUpdateSliceShape(
                            operand_shape, update_shape, start_indices_shapes));
    *instr.mutable_shape() = shape.ToProto();

    // Instruction operands: array, update, then the start indices.
    std::vector<XlaOp> operands;
    operands.reserve(2 + start_indices.size());
    operands.push_back(operand);
    operands.push_back(update);
    operands.insert(operands.end(), start_indices.begin(), start_indices.end());
    return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
                          operands);
  });
}
851 
ConcatInDim(absl::Span<const XlaOp> operands,int64 dimension)852 XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
853                               int64 dimension) {
854   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
855     HloInstructionProto instr;
856 
857     std::vector<const Shape*> operand_shape_ptrs;
858     TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
859     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
860                       [](const Shape& shape) { return &shape; });
861     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
862                                          operand_shape_ptrs, dimension));
863     *instr.mutable_shape() = shape.ToProto();
864 
865     instr.add_dimensions(dimension);
866 
867     return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
868   });
869 }
870 
// Adds a kPad instruction with operands `operand` and `padding_value`,
// storing `padding_config` on the instruction. The result shape comes from
// shape inference.
XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
                      const PaddingConfig& padding_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& padding_value_shape,
                        GetShape(padding_value));
    TF_ASSIGN_OR_RETURN(Shape padded_shape,
                        ShapeInference::InferPadShape(
                            operand_shape, padding_value_shape,
                            padding_config));

    HloInstructionProto pad_instr;
    *pad_instr.mutable_shape() = padded_shape.ToProto();
    *pad_instr.mutable_padding_config() = padding_config;

    return AddInstruction(std::move(pad_instr), HloOpcode::kPad,
                          {operand, padding_value});
  });
}
889 
// Reshapes `operand` to `new_sizes`. When `dimensions` is not the identity
// permutation the operand is first transposed by `dimensions`; the inferred
// shape is then applied with a reshape instruction.
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
                          absl::Span<const int64> dimensions,
                          absl::Span<const int64> new_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferReshapeShape(
                            operand_shape, dimensions, new_sizes));

    // Skip the transpose entirely when it would not reorder anything.
    XlaOp reshape_input = operand;
    if (!IsIdentityPermutation(dimensions)) {
      reshape_input = Transpose(operand, dimensions);
    }
    return Reshape(shape, reshape_input);
  });
}
904 
// Reshapes `operand` to `new_sizes` keeping its dimensions in their existing
// order: delegates to the three-argument Reshape with the dimension list
// 0, 1, ..., n-1.
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
                          absl::Span<const int64> new_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape operand_shape, GetShape(operand));

    const int64 num_dims = operand_shape.dimensions_size();
    std::vector<int64> identity_order;
    identity_order.reserve(num_dims);
    for (int64 dim = 0; dim < num_dims; ++dim) {
      identity_order.push_back(dim);
    }
    return Reshape(operand, identity_order, new_sizes);
  });
}
914 
// Collapses the given consecutive, in-order `dimensions` of `operand` into a
// single dimension whose size is the product of their sizes, via Reshape.
// With zero or one dimension listed, `operand` is returned unchanged.
//
// Reports InvalidArgument if the dimensions are not consecutive and
// increasing, or if they fall outside the operand's dimensions.
XlaOp XlaBuilder::Collapse(const XlaOp& operand,
                           absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (dimensions.size() <= 1) {
      // Not collapsing anything, trivially we can return the operand versus
      // enqueueing a trivial reshape.
      return operand;
    }

    // Out-of-order collapse is not supported.
    // Checks that the collapsed dimensions are in order and consecutive.
    for (absl::Span<const int64>::size_type i = 1; i < dimensions.size(); ++i) {
      if (dimensions[i] - 1 != dimensions[i - 1]) {
        return InvalidArgument(
            "Collapsed dimensions are not in consecutive order.");
      }
    }

    // Create a new sizes vector from the old shape, replacing the collapsed
    // dimensions by the product of their sizes.
    TF_ASSIGN_OR_RETURN(const Shape& original_shape, GetShape(operand));

    // The dimensions are consecutive and increasing, so checking the first and
    // last bounds the whole range. A negative first dimension would otherwise
    // make the loop below call back() on an empty vector, and a last dimension
    // beyond the rank would be silently ignored.
    if (dimensions.front() < 0 || dimensions.back() >= original_shape.rank()) {
      return InvalidArgument(
          "Collapsed dimensions [%s] are out of bounds for shape %s",
          absl::StrJoin(dimensions, ","),
          ShapeUtil::HumanString(original_shape));
    }

    VLOG(3) << "original shape: " << ShapeUtil::HumanString(original_shape);
    VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");

    std::vector<int64> new_sizes;
    for (int i = 0; i < original_shape.rank(); ++i) {
      if (i <= dimensions.front() || i > dimensions.back()) {
        // Dimensions outside the collapsed range, and the first collapsed
        // dimension itself, each start a new output dimension.
        new_sizes.push_back(original_shape.dimensions(i));
      } else {
        // Remaining collapsed dimensions are folded into the previous entry.
        new_sizes.back() *= original_shape.dimensions(i);
      }
    }

    VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]";

    return Reshape(operand, new_sizes);
  });
}
954 
Trace(const string & tag,const XlaOp & operand)955 void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
956   ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
957     HloInstructionProto instr;
958     *instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
959     *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
960     return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
961   });
962 }
963 
// Builds a select over `pred`, `on_true` and `on_false`. Both branches must
// agree on being tuples or not; tuple branches use kTupleSelect and all other
// shapes use kSelect.
XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
                         const XlaOp& on_false) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& true_shape, GetShape(on_true));
    TF_ASSIGN_OR_RETURN(const Shape& false_shape, GetShape(on_false));
    TF_RET_CHECK(true_shape.IsTuple() == false_shape.IsTuple());

    if (true_shape.IsTuple()) {
      return TernaryOp(HloOpcode::kTupleSelect, pred, on_true, on_false);
    }
    return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false);
  });
}
975 
Tuple(absl::Span<const XlaOp> elements)976 XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
977   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
978     HloInstructionProto instr;
979     std::vector<const Shape*> operand_shape_ptrs;
980     TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
981     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
982                       [](const Shape& shape) { return &shape; });
983     TF_ASSIGN_OR_RETURN(const Shape shape,
984                         ShapeInference::InferVariadicOpShape(
985                             HloOpcode::kTuple, operand_shape_ptrs));
986     *instr.mutable_shape() = shape.ToProto();
987     return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
988   });
989 }
990 
// Adds a kGetTupleElement instruction that extracts element `index` from
// `tuple_data`.
//
// Reports InvalidArgument if `tuple_data` is not tuple-shaped or if `index`
// is outside the tuple's element range.
XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data));
    if (!tuple_shape.IsTuple()) {
      return InvalidArgument(
          "Operand to GetTupleElement() is not a tuple; got %s",
          ShapeUtil::HumanString(tuple_shape));
    }
    // Validate the index here so a bad value is reported as an error status
    // rather than being passed unchecked to GetTupleElementShape below.
    if (index < 0 || index >= ShapeUtil::TupleElementCount(tuple_shape)) {
      return InvalidArgument(
          "GetTupleElement() index (%d) out of range; tuple shape: %s", index,
          ShapeUtil::HumanString(tuple_shape));
    }
    *instr.mutable_shape() =
        ShapeUtil::GetTupleElementShape(tuple_shape, index).ToProto();

    instr.set_tuple_index(index);

    return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
                          {tuple_data});
  });
}
1009 
// Builds a dot of `lhs` and `rhs` by delegating to DotGeneral. The lhs
// contracting dimension is 0 when lhs has exactly one dimension and 1
// otherwise; the rhs contracting dimension is always 0.
XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs,
                      const PrecisionConfig* precision_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));

    const int64 lhs_contracting_dim =
        (lhs_shape.dimensions_size() == 1) ? 0 : 1;

    DotDimensionNumbers dot_dims;
    dot_dims.add_lhs_contracting_dimensions(lhs_contracting_dim);
    dot_dims.add_rhs_contracting_dimensions(0);
    return DotGeneral(lhs, rhs, dot_dims, precision_config);
  });
}
1022 
// Adds a kDot instruction over `lhs` and `rhs` using `dimension_numbers` and,
// when non-null, `precision_config`. If either operand is a scalar the result
// is an element-wise multiply instead, and in that case no contracting or
// batch dimensions may be specified.
XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                             const DotDimensionNumbers& dimension_numbers,
                             const PrecisionConfig* precision_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));

    // If one operand is a scalar, just multiply the two operands.
    const bool has_scalar_operand =
        ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape);
    if (has_scalar_operand) {
      const bool specifies_dimensions =
          dimension_numbers.rhs_batch_dimensions_size() != 0 ||
          dimension_numbers.lhs_batch_dimensions_size() != 0 ||
          dimension_numbers.rhs_contracting_dimensions_size() != 0 ||
          dimension_numbers.lhs_contracting_dimensions_size() != 0;
      if (specifies_dimensions) {
        return InvalidArgument(
            "Dots with scalar operands must have no contracting or batch "
            "dimensions");
      }
      return xla::Mul(lhs, rhs);
    }

    TF_ASSIGN_OR_RETURN(Shape dot_shape,
                        ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
                                                        dimension_numbers));

    HloInstructionProto dot_instr;
    *dot_instr.mutable_shape() = dot_shape.ToProto();
    *dot_instr.mutable_dot_dimension_numbers() = dimension_numbers;
    if (precision_config != nullptr) {
      *dot_instr.mutable_precision_config() = *precision_config;
    }
    return AddInstruction(std::move(dot_instr), HloOpcode::kDot, {lhs, rhs});
  });
}
1053 
VerifyConvolution(const Shape & lhs_shape,const Shape & rhs_shape,const ConvolutionDimensionNumbers & dimension_numbers) const1054 Status XlaBuilder::VerifyConvolution(
1055     const Shape& lhs_shape, const Shape& rhs_shape,
1056     const ConvolutionDimensionNumbers& dimension_numbers) const {
1057   if (lhs_shape.rank() != rhs_shape.rank()) {
1058     return InvalidArgument(
1059         "Convolution arguments must have same number of "
1060         "dimensions. Got: %s and %s",
1061         ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
1062   }
1063   int num_dims = lhs_shape.rank();
1064   if (num_dims < 2) {
1065     return InvalidArgument(
1066         "Convolution expects argument arrays with >= 3 dimensions. "
1067         "Got: %s and %s",
1068         ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
1069   }
1070   int num_spatial_dims = num_dims - 2;
1071 
1072   const auto check_spatial_dimensions =
1073       [&](const char* const field_name,
1074           const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
1075               numbers) {
1076         if (numbers.size() != num_spatial_dims) {
1077           return InvalidArgument("Expected %d elements for %s, but got %d.",
1078                                  num_spatial_dims, field_name, numbers.size());
1079         }
1080         for (int i = 0; i < numbers.size(); ++i) {
1081           if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
1082             return InvalidArgument("Convolution %s[%d] is out of bounds: %d",
1083                                    field_name, i, numbers.Get(i));
1084           }
1085         }
1086         return Status::OK();
1087       };
1088   TF_RETURN_IF_ERROR(
1089       check_spatial_dimensions("input_spatial_dimensions",
1090                                dimension_numbers.input_spatial_dimensions()));
1091   TF_RETURN_IF_ERROR(
1092       check_spatial_dimensions("kernel_spatial_dimensions",
1093                                dimension_numbers.kernel_spatial_dimensions()));
1094   return check_spatial_dimensions(
1095       "output_spatial_dimensions",
1096       dimension_numbers.output_spatial_dimensions());
1097 }
1098 
// Convolution with a `Padding` mode and default dimension numbers, which are
// created from the number of window strides. Delegates to
// ConvWithGeneralDimensions.
XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
                       absl::Span<const int64> window_strides, Padding padding,
                       int64 feature_group_count, int64 batch_group_count,
                       const PrecisionConfig* precision_config) {
  const ConvolutionDimensionNumbers default_dimension_numbers =
      CreateDefaultConvDimensionNumbers(window_strides.size());
  return ConvWithGeneralDimensions(lhs, rhs, window_strides, padding,
                                   default_dimension_numbers,
                                   feature_group_count, batch_group_count,
                                   precision_config);
}
1108 
// Convolution with explicit (low, high) padding pairs and default dimension
// numbers, which are created from the number of window strides. Delegates to
// ConvGeneral.
XlaOp XlaBuilder::ConvWithGeneralPadding(
    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config) {
  const ConvolutionDimensionNumbers default_dimension_numbers =
      CreateDefaultConvDimensionNumbers(window_strides.size());
  return ConvGeneral(lhs, rhs, window_strides, padding,
                     default_dimension_numbers, feature_group_count,
                     batch_group_count, precision_config);
}
1118 
// Convolution with a `Padding` mode and caller-supplied dimension numbers.
// After verifying the convolution, gathers the input and kernel spatial sizes,
// turns `padding` into explicit padding via MakePadding, and delegates to
// ConvGeneral.
XlaOp XlaBuilder::ConvWithGeneralDimensions(
    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));

    TF_RETURN_IF_ERROR(
        VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers));

    // Sizes of the lhs along its spatial dimensions.
    std::vector<int64> base_area_dimensions;
    base_area_dimensions.reserve(
        dimension_numbers.input_spatial_dimensions_size());
    for (const int64 spatial_dim :
         dimension_numbers.input_spatial_dimensions()) {
      base_area_dimensions.push_back(lhs_shape.dimensions(spatial_dim));
    }

    // Sizes of the rhs (kernel) along its spatial dimensions.
    std::vector<int64> window_dimensions;
    window_dimensions.reserve(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (const int64 spatial_dim :
         dimension_numbers.kernel_spatial_dimensions()) {
      window_dimensions.push_back(rhs_shape.dimensions(spatial_dim));
    }

    const auto explicit_padding = MakePadding(
        base_area_dimensions, window_dimensions, window_strides, padding);
    return ConvGeneral(lhs, rhs, window_strides, explicit_padding,
                       dimension_numbers, feature_group_count,
                       batch_group_count, precision_config);
  });
}
1154 
// Convolution with explicit padding and dimension numbers. Delegates to
// ConvGeneralDilated with empty lhs and rhs dilation lists.
XlaOp XlaBuilder::ConvGeneral(
    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config) {
  const absl::Span<const int64> no_lhs_dilation;
  const absl::Span<const int64> no_rhs_dilation;
  return ConvGeneralDilated(lhs, rhs, window_strides, padding, no_lhs_dilation,
                            no_rhs_dilation, dimension_numbers,
                            feature_group_count, batch_group_count,
                            precision_config);
}
1165 
// Adds a kConvolution instruction over `lhs` and `rhs`. The window is built by
// MakeWindow from the kernel's spatial sizes together with the given strides,
// padding and dilations; the result shape comes from shape inference. Group
// counts, dimension numbers and, when non-null, `precision_config` are stored
// on the instruction.
XlaOp XlaBuilder::ConvGeneralDilated(
    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
    TF_RETURN_IF_ERROR(
        VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers));

    // Sizes of the rhs (kernel) along its spatial dimensions.
    std::vector<int64> window_dimensions;
    window_dimensions.reserve(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (const int64 spatial_dim :
         dimension_numbers.kernel_spatial_dimensions()) {
      window_dimensions.push_back(rhs_shape.dimensions(spatial_dim));
    }

    HloInstructionProto conv_instr;
    TF_ASSIGN_OR_RETURN(*conv_instr.mutable_window(),
                        MakeWindow(window_dimensions, window_strides, padding,
                                   lhs_dilation, rhs_dilation));

    TF_ASSIGN_OR_RETURN(Shape conv_shape,
                        ShapeInference::InferConvolveShape(
                            lhs_shape, rhs_shape, feature_group_count,
                            batch_group_count, conv_instr.window(),
                            dimension_numbers));
    *conv_instr.mutable_shape() = conv_shape.ToProto();

    *conv_instr.mutable_convolution_dimension_numbers() = dimension_numbers;
    conv_instr.set_feature_group_count(feature_group_count);
    conv_instr.set_batch_group_count(batch_group_count);
    if (precision_config != nullptr) {
      *conv_instr.mutable_precision_config() = *precision_config;
    }

    return AddInstruction(std::move(conv_instr), HloOpcode::kConvolution,
                          {lhs, rhs});
  });
}
1209 
MakeWindow(absl::Span<const int64> window_dimensions,absl::Span<const int64> window_strides,absl::Span<const std::pair<int64,int64>> padding,absl::Span<const int64> lhs_dilation,absl::Span<const int64> rhs_dilation) const1210 StatusOr<Window> XlaBuilder::MakeWindow(
1211     absl::Span<const int64> window_dimensions,
1212     absl::Span<const int64> window_strides,
1213     absl::Span<const std::pair<int64, int64>> padding,
1214     absl::Span<const int64> lhs_dilation,
1215     absl::Span<const int64> rhs_dilation) const {
1216   const auto verify_size = [&](const size_t x, const char* x_name) {
1217     if (x == 0 || x == window_dimensions.size()) {
1218       return Status::OK();
1219     } else {
1220       return InvalidArgument(
1221           "%s", absl::StrCat(
1222                     "Window has different number of window dimensions than of ",
1223                     x_name,
1224                     "\nNumber of window dimensions: ", window_dimensions.size(),
1225                     "\nNumber of ", x_name, ": ", x, "\n"));
1226     }
1227   };
1228   TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides"));
1229   TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries"));
1230   TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors"));
1231   TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors"));
1232 
1233   Window window;
1234   for (size_t i = 0; i < window_dimensions.size(); i++) {
1235     auto dim = window.add_dimensions();
1236     dim->set_size(window_dimensions[i]);
1237     if (!window_strides.empty()) {
1238       dim->set_stride(window_strides[i]);
1239     } else {
1240       dim->set_stride(1);
1241     }
1242     if (!padding.empty()) {
1243       dim->set_padding_low(padding[i].first);
1244       dim->set_padding_high(padding[i].second);
1245     } else {
1246       dim->set_padding_low(0);
1247       dim->set_padding_high(0);
1248     }
1249     if (!lhs_dilation.empty()) {
1250       dim->set_base_dilation(lhs_dilation[i]);
1251     } else {
1252       dim->set_base_dilation(1);
1253     }
1254     if (!rhs_dilation.empty()) {
1255       dim->set_window_dilation(rhs_dilation[i]);
1256     } else {
1257       dim->set_window_dilation(1);
1258     }
1259     dim->set_window_reversal(false);
1260   }
1261   return window;
1262 }
1263 
// Adds a kFft instruction over `operand`, recording `fft_type` and every
// entry of `fft_length` on the instruction. The result shape comes from shape
// inference.
XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
                      const absl::Span<const int64> fft_length) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape fft_shape,
                        ShapeInference::InferFftShape(operand_shape, fft_type,
                                                      fft_length));

    HloInstructionProto fft_instr;
    *fft_instr.mutable_shape() = fft_shape.ToProto();
    fft_instr.set_fft_type(fft_type);
    for (const int64 length : fft_length) {
      fft_instr.add_fft_length(length);
    }

    return AddInstruction(std::move(fft_instr), HloOpcode::kFft, {operand});
  });
}
1280 
Infeed(const Shape & shape,const string & config)1281 XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
1282   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1283     HloInstructionProto instr;
1284     if (!LayoutUtil::HasLayout(shape)) {
1285       return InvalidArgument("Given shape to Infeed must have a layout");
1286     }
1287     const Shape infeed_instruction_shape =
1288         ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
1289     *instr.mutable_shape() = infeed_instruction_shape.ToProto();
1290     instr.set_infeed_config(config);
1291 
1292     if (shape.IsArray() && sharding() &&
1293         sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) {
1294       // TODO(b/110793772): Support tiled array-shaped infeeds.
1295       return InvalidArgument(
1296           "Tiled sharding is not yet supported for array-shaped infeeds");
1297     }
1298 
1299     if (sharding() &&
1300         sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
1301       return InvalidArgument(
1302           "Replicated sharding is not yet supported for infeeds");
1303     }
1304 
1305     // Infeed takes a single token operand. Generate the token to pass to the
1306     // infeed.
1307     XlaOp token;
1308     auto make_token = [&]() {
1309       HloInstructionProto token_instr;
1310       *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1311       return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
1312     };
1313     if (sharding()) {
1314       // Arbitrarily assign token to device 0.
1315       OpSharding sharding = sharding_builder::AssignDevice(0);
1316       XlaScopedShardingAssignment scoped_sharding(this, sharding);
1317       TF_ASSIGN_OR_RETURN(token, make_token());
1318     } else {
1319       TF_ASSIGN_OR_RETURN(token, make_token());
1320     }
1321 
1322     // The sharding is set by the client according to the data tuple shape.
1323     // However, the shape of the infeed instruction is a tuple containing the
1324     // data and a token. For tuple sharding type, the sharding must be changed
1325     // to accommodate the token.
1326     XlaOp infeed;
1327     if (sharding() &&
1328         sharding()->type() == OpSharding::Type::OpSharding_Type_TUPLE) {
1329       // TODO(b/80000000): Remove this when clients have been updated to handle
1330       // tokens.
1331       OpSharding infeed_instruction_sharding = *sharding();
1332       // Arbitrarily assign the token to device 0.
1333       *infeed_instruction_sharding.add_tuple_shardings() =
1334           sharding_builder::AssignDevice(0);
1335       XlaScopedShardingAssignment scoped_sharding(this,
1336                                                   infeed_instruction_sharding);
1337       TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
1338                                                  HloOpcode::kInfeed, {token}));
1339     } else {
1340       TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
1341                                                  HloOpcode::kInfeed, {token}));
1342     }
1343 
1344     // The infeed instruction produces a tuple of the infed data and a token
1345     // type. Return XLA op containing the data.
1346     // TODO(b/80000000): Remove this when clients have been updated to handle
1347     // tokens.
1348     HloInstructionProto infeed_data;
1349     *infeed_data.mutable_shape() = shape.ToProto();
1350     infeed_data.set_tuple_index(0);
1351     return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
1352                           {infeed});
1353   });
1354 }
1355 
// Adds a kInfeed instruction that consumes the caller-supplied `token`. The
// instruction's shape is the tuple (shape, token) and that tuple-shaped op is
// returned as is. `shape` must have a layout; tiled sharding on array shapes
// and replicated sharding are rejected.
XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape,
                                  const string& config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Given shape to Infeed must have a layout");
    }

    // True when the builder currently has a sharding set and it is of `type`.
    const auto sharding_is = [&](OpSharding::Type type) {
      return sharding() && sharding()->type() == type;
    };

    if (shape.IsArray() &&
        sharding_is(OpSharding::Type::OpSharding_Type_OTHER)) {
      // TODO(b/110793772): Support tiled array-shaped infeeds.
      return InvalidArgument(
          "Tiled sharding is not yet supported for array-shaped infeeds");
    }
    if (sharding_is(OpSharding::Type::OpSharding_Type_REPLICATED)) {
      return InvalidArgument(
          "Replicated sharding is not yet supported for infeeds");
    }

    HloInstructionProto infeed_proto;
    *infeed_proto.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
            .ToProto();
    infeed_proto.set_infeed_config(config);

    return AddInstruction(std::move(infeed_proto), HloOpcode::kInfeed, {token});
  });
}
1384 
Outfeed(const XlaOp & operand,const Shape & shape_with_layout,const string & outfeed_config)1385 void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
1386                          const string& outfeed_config) {
1387   ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1388     HloInstructionProto instr;
1389 
1390     *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1391 
1392     // Check and set outfeed shape.
1393     if (!LayoutUtil::HasLayout(shape_with_layout)) {
1394       return InvalidArgument("Given shape to Outfeed must have a layout");
1395     }
1396     TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
1397     if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
1398       return InvalidArgument(
1399           "Outfeed shape %s must be compatible with operand shape %s",
1400           ShapeUtil::HumanStringWithLayout(shape_with_layout),
1401           ShapeUtil::HumanStringWithLayout(operand_shape));
1402     }
1403     *instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
1404 
1405     instr.set_outfeed_config(outfeed_config);
1406 
1407     // Outfeed takes a token as its second operand. Generate the token to pass
1408     // to the outfeed.
1409     HloInstructionProto token_instr;
1410     *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1411     TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
1412                                                     HloOpcode::kAfterAll, {}));
1413 
1414     TF_RETURN_IF_ERROR(
1415         AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token})
1416             .status());
1417 
1418     // The outfeed instruction produces a token. However, existing users expect
1419     // a nil shape (empty tuple). This should only be relevant if the outfeed is
1420     // the root of a computation.
1421     // TODO(b/80000000): Remove this when clients have been updated to handle
1422     // tokens.
1423     HloInstructionProto tuple_instr;
1424     *tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
1425 
1426     // The dummy tuple should have no sharding.
1427     {
1428       XlaScopedShardingAssignment scoped_sharding(this, OpSharding());
1429       TF_ASSIGN_OR_RETURN(
1430           XlaOp empty_tuple,
1431           AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {}));
1432       return empty_tuple;
1433     }
1434   });
1435 }
1436 
// Adds a token-shaped kOutfeed instruction taking (operand, token) and returns
// it. `shape_with_layout` must have a layout and be compatible with the
// operand's shape; otherwise an InvalidArgument error is reported.
XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                                   const Shape& shape_with_layout,
                                   const string& outfeed_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Check the requested outfeed shape before building anything.
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Given shape to Outfeed must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape& source_shape, GetShape(operand));
    if (!ShapeUtil::Compatible(source_shape, shape_with_layout)) {
      return InvalidArgument(
          "Outfeed shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(source_shape));
    }

    HloInstructionProto outfeed_proto;
    *outfeed_proto.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    *outfeed_proto.mutable_outfeed_shape() = shape_with_layout.ToProto();
    outfeed_proto.set_outfeed_config(outfeed_config);

    return AddInstruction(std::move(outfeed_proto), HloOpcode::kOutfeed,
                          {operand, token});
  });
}
1464 
CreateToken()1465 XlaOp XlaBuilder::CreateToken() {
1466   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1467     HloInstructionProto instr;
1468     *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1469     return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
1470   });
1471 }
1472 
AfterAll(absl::Span<const XlaOp> tokens)1473 XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
1474   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1475     if (tokens.empty()) {
1476       return InvalidArgument("AfterAll requires at least one operand");
1477     }
1478     for (int i = 0; i < tokens.size(); ++i) {
1479       const XlaOp& operand = tokens[i];
1480       TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
1481       if (!operand_shape.IsToken()) {
1482         return InvalidArgument(
1483             "All operands to AfterAll must be tokens; operand %d has shape %s",
1484             i, ShapeUtil::HumanString(operand_shape));
1485       }
1486     }
1487     HloInstructionProto instr;
1488     *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
1489     return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
1490   });
1491 }
1492 
CustomCall(const string & call_target_name,absl::Span<const XlaOp> operands,const Shape & shape,const string & opaque,absl::optional<absl::Span<const Shape>> operand_shapes_with_layout)1493 XlaOp XlaBuilder::CustomCall(
1494     const string& call_target_name, absl::Span<const XlaOp> operands,
1495     const Shape& shape, const string& opaque,
1496     absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
1497   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1498     HloInstructionProto instr;
1499     if (absl::StartsWith(call_target_name, "$")) {
1500       return InvalidArgument(
1501           "Invalid custom_call_target \"%s\": Call targets that start with '$' "
1502           "are reserved for internal use.",
1503           call_target_name);
1504     }
1505     *instr.mutable_shape() = shape.ToProto();
1506     instr.set_custom_call_target(call_target_name);
1507     instr.set_custom_call_opaque(opaque);
1508     if (operand_shapes_with_layout.has_value()) {
1509       if (!LayoutUtil::HasLayout(shape)) {
1510         return InvalidArgument(
1511             "Result shape must have layout for custom call with constrained "
1512             "layout.");
1513       }
1514       if (operands.size() != operand_shapes_with_layout->size()) {
1515         return InvalidArgument(
1516             "Must specify a shape with layout for each operand for custom call "
1517             "with constrained layout; given %d shapes, expected %d",
1518             operand_shapes_with_layout->size(), operands.size());
1519       }
1520       instr.set_constrain_layout(true);
1521       int64 operand_num = 0;
1522       for (const Shape& operand_shape : *operand_shapes_with_layout) {
1523         if (!LayoutUtil::HasLayout(operand_shape)) {
1524           return InvalidArgument(
1525               "No layout specified for operand %d for custom call with "
1526               "constrained layout.",
1527               operand_num);
1528         }
1529         *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
1530         ++operand_num;
1531       }
1532     }
1533     return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
1534   });
1535 }
1536 
// Adds a kTranspose instruction over `operand`. The result shape comes from
// ShapeInference::InferTransposeShape and `permutation` is stored in the
// instruction's dimensions field.
XlaOp XlaBuilder::Transpose(const XlaOp& operand,
                            absl::Span<const int64> permutation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(
        Shape transposed_shape,
        ShapeInference::InferTransposeShape(input_shape, permutation));

    HloInstructionProto transpose_proto;
    *transpose_proto.mutable_shape() = transposed_shape.ToProto();
    for (const int64 axis : permutation) {
      transpose_proto.add_dimensions(axis);
    }
    return AddInstruction(std::move(transpose_proto), HloOpcode::kTranspose,
                          {operand});
  });
}
1551 
// Adds a kReverse instruction over `operand`. The result shape comes from
// ShapeInference::InferReverseShape and `dimensions` is stored in the
// instruction's dimensions field.
XlaOp XlaBuilder::Rev(const XlaOp& operand,
                      absl::Span<const int64> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(
        Shape reversed_shape,
        ShapeInference::InferReverseShape(input_shape, dimensions));

    HloInstructionProto reverse_proto;
    *reverse_proto.mutable_shape() = reversed_shape.ToProto();
    for (const int64 axis : dimensions) {
      reverse_proto.add_dimensions(axis);
    }
    return AddInstruction(std::move(reverse_proto), HloOpcode::kReverse,
                          {operand});
  });
}
1566 
1567 namespace {
1568 // Switch from a floating point value to a integer value in such a way that when
1569 // using the integer value to compare, we get the same result for normal values,
1570 // and -Nan is treated as the smallest value, and Nan is treated as the largest
1571 // value.
1572 // If f is a float, and
1573 // x = bit_cast<int32>(f);
1574 // y = x < 0 ? numeric_limits<int32>::max() - x : x;
1575 // then y is ordered as an int32 such that finite values have the obvious order,
1576 // -0 is ordered before 0, and -NaN and NaN appear at the beginning and end of
1577 // the ordering.
1578 // Note that in order to avoid -x to overflow, we calculate
1579 // numeric_limits<int32>::max() - x as unsigned, and then convert back to
1580 // signed.
BitcastConvertFloatingPointToIntegral(const XlaOp & value,int64 bit_width)1581 XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value,
1582                                             int64 bit_width) {
1583   PrimitiveType signed_type;
1584   PrimitiveType unsigned_type;
1585   XlaOp max_value;
1586   switch (bit_width) {
1587     case 16:
1588       max_value =
1589           ConstantR0(value.builder(),
1590                      static_cast<uint16>(std::numeric_limits<int16>::max()));
1591       signed_type = S16;
1592       unsigned_type = U16;
1593       break;
1594     case 32:
1595       max_value =
1596           ConstantR0(value.builder(),
1597                      static_cast<uint32>(std::numeric_limits<int32>::max()));
1598       signed_type = S32;
1599       unsigned_type = U32;
1600       break;
1601     case 64:
1602       max_value =
1603           ConstantR0(value.builder(),
1604                      static_cast<uint64>(std::numeric_limits<int64>::max()));
1605       signed_type = S64;
1606       unsigned_type = U64;
1607       break;
1608     default:
1609       return value.builder()->ReportError(
1610           InvalidArgument("Invalid bit width %lld for Comparator floating "
1611                           "point parameter.",
1612                           bit_width));
1613   }
1614   auto signed_value = BitcastConvertType(value, signed_type);
1615   auto unsigned_value = BitcastConvertType(value, unsigned_type);
1616   auto flipped_value =
1617       BitcastConvertType(Sub(max_value, unsigned_value), signed_type);
1618   auto is_negative =
1619       Lt(signed_value,
1620          ConstantLiteral(value.builder(), LiteralUtil::Zero(signed_type)));
1621   return Select(is_negative, flipped_value, signed_value);
1622 }
1623 }  // namespace
1624 
Sort(const XlaOp & keys,absl::Span<const XlaOp> values,int64 dimension)1625 XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
1626                        int64 dimension) {
1627   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1628     std::vector<XlaOp> operands{keys};
1629     for (const XlaOp& value : values) {
1630       operands.push_back(value);
1631     }
1632     // Build the default less-than comparator (copied from lib/comparators.cc).
1633     // TODO(b/122298745): Remove the deprecated API method so that this code
1634     // duplication can be deleted.
1635     auto b = this->CreateSubBuilder("comparator");
1636     std::vector<PrimitiveType> operand_types;
1637     for (const XlaOp& operand : operands) {
1638       TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(operand));
1639       operand_types.push_back(operand_shape.element_type());
1640     }
1641 
1642     int64 parameter_count = 0;
1643     XlaOp first_lhs_param;
1644     XlaOp first_rhs_param;
1645 
1646     for (auto operand_type : operand_types) {
1647       auto scalar_shape = ShapeUtil::MakeShape(operand_type, {});
1648       auto lhs_param =
1649           b->Parameter(parameter_count * 2, scalar_shape,
1650                        absl::StrCat("p.", parameter_count, ".lhs"));
1651       auto rhs_param =
1652           b->Parameter(parameter_count * 2 + 1, scalar_shape,
1653                        absl::StrCat("p.", parameter_count, ".rhs"));
1654       if (parameter_count == 0) {
1655         first_lhs_param = lhs_param;
1656         first_rhs_param = rhs_param;
1657       }
1658       ++parameter_count;
1659     }
1660     if (primitive_util::IsFloatingPointType(operand_types[0])) {
1661       PrimitiveType compare_type = operand_types[0];
1662       // Special-case handling for BF16. We currently do not support direct
1663       // comparisons with BF16, so we convert to F32 and then use the F32
1664       // comparison logic.
1665       if (compare_type == BF16) {
1666         compare_type = F32;
1667         first_lhs_param = b->ConvertElementType(first_lhs_param, F32);
1668         first_rhs_param = b->ConvertElementType(first_rhs_param, F32);
1669       }
1670       int64 bit_width = primitive_util::BitWidth(compare_type);
1671       first_lhs_param =
1672           BitcastConvertFloatingPointToIntegral(first_lhs_param, bit_width);
1673       first_rhs_param =
1674           BitcastConvertFloatingPointToIntegral(first_rhs_param, bit_width);
1675     }
1676     Lt(first_lhs_param, first_rhs_param);
1677 
1678     TF_ASSIGN_OR_RETURN(auto comparator, b->Build());
1679     return Sort(operands, comparator, dimension, /*is_stable=*/false);
1680   });
1681 }
1682 
Sort(absl::Span<const XlaOp> operands,const XlaComputation & comparator,int64 dimension,bool is_stable)1683 XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
1684                        const XlaComputation& comparator, int64 dimension,
1685                        bool is_stable) {
1686   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1687     HloInstructionProto instr;
1688     instr.set_is_stable(is_stable);
1689     std::vector<const Shape*> operand_shape_ptrs;
1690     TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
1691                         GetOperandShapes(operands));
1692     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1693                       [](const Shape& shape) { return &shape; });
1694     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
1695                                          HloOpcode::kSort, operand_shape_ptrs));
1696     *instr.mutable_shape() = shape.ToProto();
1697     if (dimension == -1) {
1698       TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(operands[0]));
1699       dimension = keys_shape.rank() - 1;
1700     }
1701     instr.add_dimensions(dimension);
1702     AddCalledComputation(comparator, &instr);
1703     return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
1704   });
1705 }
1706 
// Adds a kConvert instruction over `operand`; the result shape is obtained
// from ShapeInference::InferConvertShape for `new_element_type`.
XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(
        Shape converted_shape,
        ShapeInference::InferConvertShape(input_shape, new_element_type));

    HloInstructionProto convert_proto;
    *convert_proto.mutable_shape() = converted_shape.ToProto();
    return AddInstruction(std::move(convert_proto), HloOpcode::kConvert,
                          {operand});
  });
}
1718 
// Adds a kBitcastConvert instruction over `operand`; the result shape is
// obtained from ShapeInference::InferConvertShape for `new_element_type`.
XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(
        Shape converted_shape,
        ShapeInference::InferConvertShape(input_shape, new_element_type));

    HloInstructionProto bitcast_proto;
    *bitcast_proto.mutable_shape() = converted_shape.ToProto();
    return AddInstruction(std::move(bitcast_proto), HloOpcode::kBitcastConvert,
                          {operand});
  });
}
1731 
// Clamp is emitted as a ternary op with opcode kClamp; the operands are passed
// in the order (min, operand, max).
XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand,
                        const XlaOp& max) {
  const HloOpcode opcode = HloOpcode::kClamp;
  return TernaryOp(opcode, min, operand, max);
}
1736 
Map(absl::Span<const XlaOp> operands,const XlaComputation & computation,absl::Span<const int64> dimensions,absl::Span<const XlaOp> static_operands)1737 XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
1738                       const XlaComputation& computation,
1739                       absl::Span<const int64> dimensions,
1740                       absl::Span<const XlaOp> static_operands) {
1741   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1742     if (!static_operands.empty()) {
1743       return Unimplemented("static_operands is not supported in Map");
1744     }
1745 
1746     HloInstructionProto instr;
1747     std::vector<const Shape*> operand_shape_ptrs;
1748     TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
1749     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1750                       [](const Shape& shape) { return &shape; });
1751     TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
1752                         computation.GetProgramShape());
1753     TF_ASSIGN_OR_RETURN(
1754         Shape shape, ShapeInference::InferMapShape(
1755                          operand_shape_ptrs, called_program_shape, dimensions));
1756     *instr.mutable_shape() = shape.ToProto();
1757 
1758     Shape output_shape(instr.shape());
1759     const int64 output_rank = output_shape.rank();
1760     AddCalledComputation(computation, &instr);
1761     std::vector<XlaOp> new_operands(operands.begin(), operands.end());
1762     for (XlaOp& new_operand : new_operands) {
1763       TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand));
1764       const int64 rank = shape.rank();
1765       if (rank != output_rank) {
1766         TF_ASSIGN_OR_RETURN(new_operand,
1767                             InDimBroadcast(output_shape, new_operand, {}));
1768         TF_ASSIGN_OR_RETURN(shape, GetShape(new_operand));
1769       }
1770       if (!ShapeUtil::SameDimensions(output_shape, shape)) {
1771         TF_ASSIGN_OR_RETURN(new_operand,
1772                             AddBroadcastSequence(output_shape, new_operand));
1773       }
1774     }
1775 
1776     return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands);
1777   });
1778 }
1779 
RngOp(RandomDistribution distribution,absl::Span<const XlaOp> parameters,const Shape & shape)1780 XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
1781                         absl::Span<const XlaOp> parameters,
1782                         const Shape& shape) {
1783   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1784     HloInstructionProto instr;
1785 
1786     // Check the number of parameters per RNG distribution.
1787     switch (distribution) {
1788       case RandomDistribution::RNG_NORMAL:
1789       case RandomDistribution::RNG_UNIFORM:
1790         if (parameters.size() != 2) {
1791           return InvalidArgument(
1792               "RNG distribution (%s) expects 2 parameters, but got %ld",
1793               RandomDistribution_Name(distribution), parameters.size());
1794         }
1795         break;
1796       default:
1797         LOG(FATAL) << "unhandled distribution " << distribution;
1798     }
1799 
1800     TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
1801     *instr.mutable_shape() = shape.ToProto();
1802 
1803     instr.set_distribution(distribution);
1804 
1805     return AddInstruction(std::move(instr), HloOpcode::kRng, parameters);
1806   });
1807 }
1808 
// Forwards to RngOp with RNG_NORMAL and the parameters (mu, sigma), in that
// order.
XlaOp XlaBuilder::RngNormal(const XlaOp& mu, const XlaOp& sigma,
                            const Shape& shape) {
  const XlaOp parameters[] = {mu, sigma};
  return RngOp(RandomDistribution::RNG_NORMAL, parameters, shape);
}
1813 
// Forwards to RngOp with RNG_UNIFORM and the parameters (a, b), in that order.
XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b,
                             const Shape& shape) {
  const XlaOp parameters[] = {a, b};
  return RngOp(RandomDistribution::RNG_UNIFORM, parameters, shape);
}
1818 
// Adds a kWhile instruction with `init` as its only operand. The result shape
// comes from ShapeInference::InferWhileShape given the program shapes of
// `condition` and `body` and the shape of `init`.
XlaOp XlaBuilder::While(const XlaComputation& condition,
                        const XlaComputation& body, const XlaOp& init) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Gather the inputs to shape inference.
    TF_ASSIGN_OR_RETURN(const auto& body_signature, body.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const auto& condition_signature,
                        condition.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const Shape& initial_shape, GetShape(init));
    TF_ASSIGN_OR_RETURN(
        Shape loop_shape,
        ShapeInference::InferWhileShape(condition_signature, body_signature,
                                        initial_shape));

    HloInstructionProto while_proto;
    *while_proto.mutable_shape() = loop_shape.ToProto();
    // Body comes before condition computation in the vector.
    AddCalledComputation(body, &while_proto);
    AddCalledComputation(condition, &while_proto);
    return AddInstruction(std::move(while_proto), HloOpcode::kWhile, {init});
  });
}
1839 
// Adds a kGather instruction with operands (input, start_indices). The result
// shape comes from ShapeInference::InferGatherShape; `dimension_numbers` and
// `slice_sizes` are copied into the instruction.
XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices,
                         const GatherDimensionNumbers& dimension_numbers,
                         absl::Span<const int64> slice_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& source_shape, GetShape(input));
    TF_ASSIGN_OR_RETURN(const Shape& indices_shape, GetShape(start_indices));
    TF_ASSIGN_OR_RETURN(
        Shape gathered_shape,
        ShapeInference::InferGatherShape(source_shape, indices_shape,
                                         dimension_numbers, slice_sizes));

    HloInstructionProto gather_proto;
    *gather_proto.mutable_shape() = gathered_shape.ToProto();
    *gather_proto.mutable_gather_dimension_numbers() = dimension_numbers;
    for (const int64 slice_size : slice_sizes) {
      gather_proto.add_gather_slice_sizes(slice_size);
    }

    return AddInstruction(std::move(gather_proto), HloOpcode::kGather,
                          {input, start_indices});
  });
}
1863 
// Adds a kScatter instruction with operands (input, scatter_indices, updates)
// that calls `update_computation`. The result shape comes from
// ShapeInference::InferScatterShape; `dimension_numbers` is copied into the
// instruction.
XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices,
                          const XlaOp& updates,
                          const XlaComputation& update_computation,
                          const ScatterDimensionNumbers& dimension_numbers) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& target_shape, GetShape(input));
    TF_ASSIGN_OR_RETURN(const Shape& indices_shape, GetShape(scatter_indices));
    TF_ASSIGN_OR_RETURN(const Shape& update_values_shape, GetShape(updates));
    TF_ASSIGN_OR_RETURN(const ProgramShape& update_signature,
                        update_computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(
        Shape scattered_shape,
        ShapeInference::InferScatterShape(target_shape, indices_shape,
                                          update_values_shape,
                                          update_signature, dimension_numbers));

    HloInstructionProto scatter_proto;
    *scatter_proto.mutable_shape() = scattered_shape.ToProto();
    *scatter_proto.mutable_scatter_dimension_numbers() = dimension_numbers;

    AddCalledComputation(update_computation, &scatter_proto);
    return AddInstruction(std::move(scatter_proto), HloOpcode::kScatter,
                          {input, scatter_indices, updates});
  });
}
1890 
// Two-branch form: forwards to the N-way Conditional overload with `predicate`
// as the branch selector.
XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                              const XlaComputation& true_computation,
                              const XlaOp& false_operand,
                              const XlaComputation& false_computation) {
  // The index of true_computation must be 0 and that of false computation
  // must be 1.
  const XlaComputation* const branches[] = {&true_computation,
                                            &false_computation};
  const XlaOp branch_inputs[] = {true_operand, false_operand};
  return Conditional(predicate, branches, branch_inputs);
}
1900 
// Adds a kConditional instruction whose operands are `branch_index` followed
// by `branch_operands`, and whose called computations are
// `branch_computations` in the order given. The result shape comes from
// ShapeInference::InferConditionalShape.
//
// `branch_computations` and `branch_operands` are indexed in lockstep, so they
// must have the same length; a mismatch is reported as InvalidArgument rather
// than indexing past the end of the shorter span.
XlaOp XlaBuilder::Conditional(
    const XlaOp& branch_index,
    absl::Span<const XlaComputation* const> branch_computations,
    absl::Span<const XlaOp> branch_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;

    TF_ASSIGN_OR_RETURN(const Shape& branch_index_shape,
                        GetShape(branch_index));
    if (branch_computations.size() != branch_operands.size()) {
      return InvalidArgument(
          "Conditional requires one operand per branch computation; given %d "
          "computations and %d operands",
          branch_computations.size(), branch_operands.size());
    }

    const size_t branch_count = branch_operands.size();
    std::vector<Shape> branch_operand_shapes(branch_count);
    std::vector<ProgramShape> branch_computation_shapes(branch_count);
    for (size_t j = 0; j < branch_count; ++j) {
      TF_ASSIGN_OR_RETURN(branch_operand_shapes[j],
                          GetShape(branch_operands[j]));
      TF_ASSIGN_OR_RETURN(branch_computation_shapes[j],
                          branch_computations[j]->GetProgramShape());
    }
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferConditionalShape(
                            branch_index_shape, branch_computation_shapes,
                            branch_operand_shapes));
    *instr.mutable_shape() = shape.ToProto();

    for (const XlaComputation* branch_computation : branch_computations) {
      AddCalledComputation(*branch_computation, &instr);
    }

    // Operand 0 is the branch selector; the per-branch operands follow.
    std::vector<XlaOp> operands;
    operands.reserve(branch_count + 1);
    operands.push_back(branch_index);
    operands.insert(operands.end(), branch_operands.begin(),
                    branch_operands.end());
    return AddInstruction(std::move(instr), HloOpcode::kConditional,
                          absl::MakeSpan(operands));
  });
}
1937 
// Single-operand form: wraps `operand` and `init_value` as one-element spans
// and forwards to the span-based Reduce overload.
XlaOp XlaBuilder::Reduce(const XlaOp& operand, const XlaOp& init_value,
                         const XlaComputation& computation,
                         absl::Span<const int64> dimensions_to_reduce) {
  const absl::Span<const XlaOp> single_operand(&operand, 1);
  const absl::Span<const XlaOp> single_init(&init_value, 1);
  return Reduce(single_operand, single_init, computation,
                dimensions_to_reduce);
}
1945 
Reduce(absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values,const XlaComputation & computation,absl::Span<const int64> dimensions_to_reduce)1946 XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
1947                          absl::Span<const XlaOp> init_values,
1948                          const XlaComputation& computation,
1949                          absl::Span<const int64> dimensions_to_reduce) {
1950   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
1951     HloInstructionProto instr;
1952 
1953     TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
1954                         computation.GetProgramShape());
1955 
1956     std::vector<XlaOp> all_operands;
1957     all_operands.insert(all_operands.end(), operands.begin(), operands.end());
1958     all_operands.insert(all_operands.end(), init_values.begin(),
1959                         init_values.end());
1960 
1961     std::vector<const Shape*> operand_shape_ptrs;
1962     TF_ASSIGN_OR_RETURN(const auto& operand_shapes,
1963                         GetOperandShapes(all_operands));
1964     absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
1965                       [](const Shape& shape) { return &shape; });
1966 
1967     TF_ASSIGN_OR_RETURN(
1968         Shape shape,
1969         ShapeInference::InferReduceShape(
1970             operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
1971     *instr.mutable_shape() = shape.ToProto();
1972 
1973     for (int64 dim : dimensions_to_reduce) {
1974       instr.add_dimensions(dim);
1975     }
1976 
1977     AddCalledComputation(computation, &instr);
1978 
1979     return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
1980   });
1981 }
1982 
// Reduces `operand` over all of its dimensions by forwarding to Reduce with
// the dimension list 0, 1, ..., rank - 1.
XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                            const XlaComputation& computation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));

    const int64 rank = operand_shape.rank();
    std::vector<int64> every_dimension;
    for (int64 dimno = 0; dimno < rank; ++dimno) {
      every_dimension.push_back(dimno);
    }
    return Reduce(operand, init_value, computation, every_dimension);
  });
}
1992 
// Reduce-window with a `Padding` mode: validates the window dimensions and
// strides against the operand's dimensions, converts `padding` into explicit
// per-dimension (low, high) pairs via MakePadding, and forwards to
// ReduceWindowWithGeneralPadding with empty base and window dilations.
XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                               const XlaComputation& computation,
                               absl::Span<const int64> window_dimensions,
                               absl::Span<const int64> window_strides,
                               Padding padding) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // No instruction proto is built here; the general-padding overload
    // creates the instruction.
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_RETURN_IF_ERROR(
        ValidatePaddingValues(AsInt64Slice(operand_shape.dimensions()),
                              window_dimensions, window_strides));

    std::vector<std::pair<int64, int64>> padding_values =
        MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions,
                    window_strides, padding);
    return ReduceWindowWithGeneralPadding(
        operand, init_value, computation, window_dimensions, window_strides,
        /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
  });
}
2014 
// Enqueues a kReduceWindow instruction over {operand, init_value} that calls
// `computation`. The window is built by MakeWindow from the given dimensions,
// strides, padding and dilations, and the result shape is produced by
// ShapeInference::InferReduceWindowShape.
XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
    const XlaOp& operand, const XlaOp& init_value,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
    TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                        computation.GetProgramShape());

    // The window is stored on the instruction first because shape inference
    // below reads it back through instr.window().
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
                        MakeWindow(window_dimensions, window_strides, padding,
                                   /*lhs_dilation=*/base_dilations,
                                   /*rhs_dilation=*/window_dilations));
    TF_ASSIGN_OR_RETURN(
        Shape result_shape,
        ShapeInference::InferReduceWindowShape(operand_shape, init_shape,
                                               instr.window(), to_apply_shape));
    *instr.mutable_shape() = result_shape.ToProto();
    AddCalledComputation(computation, &instr);

    return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
                          {operand, init_value});
  });
}
2044 
// Enqueues a kBatchNormTraining instruction over {operand, scale, offset},
// recording `epsilon` and `feature_index` on the instruction. The result shape
// is produced by ShapeInference::InferBatchNormTrainingShape.
XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                                    const XlaOp& offset, float epsilon,
                                    int64 feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale));
    TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset));
    TF_ASSIGN_OR_RETURN(
        Shape result_shape,
        ShapeInference::InferBatchNormTrainingShape(
            operand_shape, scale_shape, offset_shape, feature_index));

    HloInstructionProto instr;
    instr.set_feature_index(feature_index);
    instr.set_epsilon(epsilon);
    *instr.mutable_shape() = result_shape.ToProto();

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormTraining,
                          {operand, scale, offset});
  });
}
2067 
// Enqueues a kBatchNormInference instruction over
// {operand, scale, offset, mean, variance}, recording `epsilon` and
// `feature_index` on the instruction. The result shape is produced by
// ShapeInference::InferBatchNormInferenceShape.
XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                                     const XlaOp& offset, const XlaOp& mean,
                                     const XlaOp& variance, float epsilon,
                                     int64 feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale));
    TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset));
    TF_ASSIGN_OR_RETURN(const Shape& mean_shape, GetShape(mean));
    TF_ASSIGN_OR_RETURN(const Shape& variance_shape, GetShape(variance));
    TF_ASSIGN_OR_RETURN(
        Shape result_shape,
        ShapeInference::InferBatchNormInferenceShape(
            operand_shape, scale_shape, offset_shape, mean_shape,
            variance_shape, feature_index));

    HloInstructionProto instr;
    instr.set_feature_index(feature_index);
    instr.set_epsilon(epsilon);
    *instr.mutable_shape() = result_shape.ToProto();

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormInference,
                          {operand, scale, offset, mean, variance});
  });
}
2093 
// Enqueues a kBatchNormGrad instruction over
// {operand, scale, batch_mean, batch_var, grad_output}, recording `epsilon`
// and `feature_index` on the instruction. The result shape is produced by
// ShapeInference::InferBatchNormGradShape.
XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                                const XlaOp& batch_mean, const XlaOp& batch_var,
                                const XlaOp& grad_output, float epsilon,
                                int64 feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale));
    TF_ASSIGN_OR_RETURN(const Shape& batch_mean_shape, GetShape(batch_mean));
    TF_ASSIGN_OR_RETURN(const Shape& batch_var_shape, GetShape(batch_var));
    TF_ASSIGN_OR_RETURN(const Shape& grad_output_shape, GetShape(grad_output));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferBatchNormGradShape(
                            operand_shape, scale_shape, batch_mean_shape,
                            batch_var_shape, grad_output_shape, feature_index));

    HloInstructionProto instr;
    instr.set_feature_index(feature_index);
    instr.set_epsilon(epsilon);
    *instr.mutable_shape() = result_shape.ToProto();

    return AddInstruction(std::move(instr), HloOpcode::kBatchNormGrad,
                          {operand, scale, batch_mean, batch_var, grad_output});
  });
}
2119 
// Convenience overload: builds a two-parameter scalar "sum" computation of the
// operand's element type in a sub-builder and forwards to the
// computation-taking CrossReplicaSum overload without a channel id.
XlaOp XlaBuilder::CrossReplicaSum(
    const XlaOp& operand, absl::Span<const ReplicaGroup> replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    const Shape scalar_shape =
        ShapeUtil::MakeShape(operand_shape.element_type(), {});

    auto sum_builder = CreateSubBuilder("sum");
    Add(sum_builder->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
        sum_builder->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
    TF_ASSIGN_OR_RETURN(auto sum_computation, sum_builder->Build());

    return CrossReplicaSum(operand, sum_computation, replica_groups,
                           /*channel_id=*/absl::nullopt);
  });
}
2133 
// Enqueues a kAllReduce instruction over `operand` that calls `computation`.
// Every entry of `replica_groups` is copied onto the instruction, and the
// all-reduce id is set only when `channel_id` holds a value.
XlaOp XlaBuilder::CrossReplicaSum(
    const XlaOp& operand, const XlaComputation& computation,
    absl::Span<const ReplicaGroup> replica_groups,
    const absl::optional<ChannelHandle>& channel_id) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferAllReduceShape({&operand_shape}));

    HloInstructionProto instr;
    *instr.mutable_shape() = result_shape.ToProto();
    for (const ReplicaGroup& replica_group : replica_groups) {
      *instr.add_replica_groups() = replica_group;
    }
    if (channel_id.has_value()) {
      instr.set_all_reduce_id(channel_id->handle());
    }
    AddCalledComputation(computation, &instr);

    return AddInstruction(std::move(instr), HloOpcode::kAllReduce, {operand});
  });
}
2158 
AllToAll(const XlaOp & operand,int64 split_dimension,int64 concat_dimension,int64 split_count,const std::vector<ReplicaGroup> & replica_groups)2159 XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension,
2160                            int64 concat_dimension, int64 split_count,
2161                            const std::vector<ReplicaGroup>& replica_groups) {
2162   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2163     TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
2164 
2165     // The HloInstruction for Alltoall currently only handles the data
2166     // communication: it accepts N already split parts and scatters them to N
2167     // cores, and each core gathers the N received parts into a tuple as the
2168     // output. So here we explicitly split the operand before the hlo alltoall,
2169     // and concat the tuple elements.
2170     //
2171     // First, run shape inference to make sure the shapes are valid.
2172     TF_RETURN_IF_ERROR(
2173         ShapeInference::InferAllToAllShape(operand_shape, split_dimension,
2174                                            concat_dimension, split_count)
2175             .status());
2176 
2177     // Split into N parts.
2178     std::vector<XlaOp> slices;
2179     slices.reserve(split_count);
2180     const int64 block_size =
2181         operand_shape.dimensions(split_dimension) / split_count;
2182     for (int i = 0; i < split_count; i++) {
2183       slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
2184                                   /*limit_index=*/(i + 1) * block_size,
2185                                   /*stride=*/1, /*dimno=*/split_dimension));
2186     }
2187 
2188     // Handle data communication.
2189     HloInstructionProto instr;
2190     TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices));
2191     std::vector<const Shape*> slice_shape_ptrs;
2192     absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
2193                       [](const Shape& shape) { return &shape; });
2194     TF_ASSIGN_OR_RETURN(
2195         Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
2196     *instr.mutable_shape() = shape.ToProto();
2197     for (const ReplicaGroup& group : replica_groups) {
2198       *instr.add_replica_groups() = group;
2199     }
2200     TF_ASSIGN_OR_RETURN(
2201         XlaOp alltoall,
2202         AddInstruction(std::move(instr), HloOpcode::kAllToAll, slices));
2203 
2204     // Concat the N received parts.
2205     std::vector<XlaOp> received;
2206     received.reserve(split_count);
2207     for (int i = 0; i < split_count; i++) {
2208       received.push_back(this->GetTupleElement(alltoall, i));
2209     }
2210     return this->ConcatInDim(received, concat_dimension);
2211   });
2212 }
2213 
// Enqueues a kCollectivePermute instruction over `operand`, copying each
// (source, target) pair from `source_target_pairs` onto the instruction.
XlaOp XlaBuilder::CollectivePermute(
    const XlaOp& operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(
        Shape result_shape,
        ShapeInference::InferCollectivePermuteShape(operand_shape));

    HloInstructionProto instr;
    *instr.mutable_shape() = result_shape.ToProto();
    for (const auto& source_and_target : source_target_pairs) {
      auto* entry = instr.add_source_target_pairs();
      entry->set_source(source_and_target.first);
      entry->set_target(source_and_target.second);
    }

    return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute,
                          {operand});
  });
}
2235 
ReplicaId()2236 XlaOp XlaBuilder::ReplicaId() {
2237   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2238     HloInstructionProto instr;
2239     *instr.mutable_shape() = ShapeUtil::MakeShape(U32, {}).ToProto();
2240     return AddInstruction(std::move(instr), HloOpcode::kReplicaId, {});
2241   });
2242 }
2243 
// Select-and-scatter with a `Padding` mode: converts `padding` into explicit
// per-dimension (low, high) pairs via MakePadding and forwards to
// SelectAndScatterWithGeneralPadding.
//
// The window dimensions and strides are validated against the operand's
// dimensions first, as ReduceWindow does, so that malformed arguments are
// reported through the builder's status instead of being passed to
// MakePadding unchecked.
XlaOp XlaBuilder::SelectAndScatter(const XlaOp& operand,
                                   const XlaComputation& select,
                                   absl::Span<const int64> window_dimensions,
                                   absl::Span<const int64> window_strides,
                                   Padding padding, const XlaOp& source,
                                   const XlaOp& init_value,
                                   const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_RETURN_IF_ERROR(
        ValidatePaddingValues(AsInt64Slice(operand_shape.dimensions()),
                              window_dimensions, window_strides));

    std::vector<std::pair<int64, int64>> padding_values =
        MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions,
                    window_strides, padding);
    return SelectAndScatterWithGeneralPadding(
        operand, select, window_dimensions, window_strides, padding_values,
        source, init_value, scatter);
  });
}
2260 
// Enqueues a kSelectAndScatter instruction over {operand, source, init_value}
// that calls `select` and then `scatter`. The window is built by MakeWindow
// with no dilations, and the result shape is produced by
// ShapeInference::InferSelectAndScatterShape.
XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
    const XlaOp& operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
    const XlaOp& init_value, const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& source_shape, GetShape(source));
    TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
    TF_ASSIGN_OR_RETURN(const ProgramShape& select_shape,
                        select.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_shape,
                        scatter.GetProgramShape());

    // The window is stored on the instruction first because shape inference
    // below reads it back through instr.window().
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
                        MakeWindow(window_dimensions, window_strides, padding,
                                   /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferSelectAndScatterShape(
                            operand_shape, select_shape, instr.window(),
                            source_shape, init_shape, scatter_shape));
    *instr.mutable_shape() = result_shape.ToProto();

    // Called computations are registered in the order: select, scatter.
    AddCalledComputation(select, &instr);
    AddCalledComputation(scatter, &instr);

    return AddInstruction(std::move(instr), HloOpcode::kSelectAndScatter,
                          {operand, source, init_value});
  });
}
2293 
// Enqueues a kReducePrecision instruction over `operand`, recording
// `exponent_bits` and `mantissa_bits` on the instruction. The result shape is
// produced by ShapeInference::InferReducePrecisionShape.
XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits,
                                  const int mantissa_bits) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferReducePrecisionShape(
                            operand_shape, exponent_bits, mantissa_bits));

    HloInstructionProto instr;
    instr.set_mantissa_bits(mantissa_bits);
    instr.set_exponent_bits(exponent_bits);
    *instr.mutable_shape() = result_shape.ToProto();

    return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
                          {operand});
  });
}
2309 
Send(const XlaOp & operand,const ChannelHandle & handle)2310 void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
2311   ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2312     // Send HLO takes two operands: a data operand and a token. Generate the
2313     // token to pass into the send.
2314     // TODO(b/80000000): Remove this when clients have been updated to handle
2315     // tokens.
2316     HloInstructionProto token_instr;
2317     *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
2318     TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
2319                                                     HloOpcode::kAfterAll, {}));
2320 
2321     return SendWithToken(operand, token, handle);
2322   });
2323 }
2324 
// Enqueues a kSend over {operand, token} followed by a kSendDone over the
// send, both tagged with the channel id of `handle`. Returns the kSendDone op,
// whose shape is a token. Rejects channels that are not device-to-device.
XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Send must use a device-to-device channel");
    }

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    const Shape context_shape = ShapeUtil::MakeShape(U32, {});
    const Shape token_shape = ShapeUtil::MakeTokenShape();

    // Send instruction produces a tuple of {aliased operand, U32 context,
    // token}.
    HloInstructionProto send_instr;
    send_instr.set_channel_id(handle.handle());
    *send_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({operand_shape, context_shape, token_shape})
            .ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp send,
                        AddInstruction(std::move(send_instr), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto send_done_instr;
    send_done_instr.set_channel_id(handle.handle());
    *send_done_instr.mutable_shape() = token_shape.ToProto();
    return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
                          {send});
  });
}
2352 
Recv(const Shape & shape,const ChannelHandle & handle)2353 XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
2354   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
2355     // Recv HLO takes a single token operand. Generate the token to pass into
2356     // the Recv and RecvDone instructions.
2357     // TODO(b/80000000): Remove this when clients have been updated to handle
2358     // tokens.
2359     HloInstructionProto token_instr;
2360     *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
2361     TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
2362                                                     HloOpcode::kAfterAll, {}));
2363 
2364     XlaOp recv = RecvWithToken(token, shape, handle);
2365 
2366     // The RecvDone instruction produces a tuple of the data and a token
2367     // type. Return XLA op containing the data.
2368     // TODO(b/80000000): Remove this when clients have been updated to handle
2369     // tokens.
2370     HloInstructionProto recv_data;
2371     *recv_data.mutable_shape() = shape.ToProto();
2372     recv_data.set_tuple_index(0);
2373     return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
2374                           {recv});
2375   });
2376 }
2377 
// Enqueues a kRecv over {token} followed by a kRecvDone over the recv, both
// tagged with the channel id of `handle`. Returns the kRecvDone op, whose
// shape is the tuple {shape, token}. Rejects channels that are not
// device-to-device.
XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Recv must use a device-to-device channel");
    }

    const Shape context_shape = ShapeUtil::MakeShape(U32, {});
    const Shape token_shape = ShapeUtil::MakeTokenShape();

    // Recv instruction produces a tuple of {receive buffer, U32 context,
    // token}.
    HloInstructionProto recv_instr;
    recv_instr.set_channel_id(handle.handle());
    *recv_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, context_shape, token_shape})
            .ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
                                                   HloOpcode::kRecv, {token}));

    HloInstructionProto recv_done_instr;
    recv_done_instr.set_channel_id(handle.handle());
    *recv_done_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, token_shape}).ToProto();
    return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
                          {recv});
  });
}
2405 
// Enqueues a host-transfer kSend over {operand, token} followed by a
// host-transfer kSendDone, both tagged with the channel id of `handle`.
// Returns the kSendDone op, whose shape is a token.
//
// Fails with InvalidArgument when `shape_with_layout` has no layout, is not
// compatible with the operand's shape, when the operand is not an array, or
// when `handle` is not a device-to-host channel (checked in that order).
XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token,
                             const Shape& shape_with_layout,
                             const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Shape passed to SendToHost must have a layout");
    }

    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    const bool shapes_compatible =
        ShapeUtil::Compatible(operand_shape, shape_with_layout);
    if (!shapes_compatible) {
      return InvalidArgument(
          "SendToHost shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(operand_shape));
    }

    // TODO(b/111544877): Support tuple shapes.
    if (!operand_shape.IsArray()) {
      return InvalidArgument("SendToHost only supports array shapes, shape: %s",
                             ShapeUtil::HumanString(operand_shape));
    }

    if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
      return InvalidArgument("SendToHost must use a device-to-host channel");
    }

    const Shape context_shape = ShapeUtil::MakeShape(U32, {});
    const Shape token_shape = ShapeUtil::MakeTokenShape();

    // Send instruction produces a tuple of {aliased operand, U32 context,
    // token}.
    HloInstructionProto send_instr;
    send_instr.set_is_host_transfer(true);
    send_instr.set_channel_id(handle.handle());
    *send_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape_with_layout, context_shape, token_shape})
            .ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp send,
                        AddInstruction(std::move(send_instr), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto send_done_instr;
    send_done_instr.set_is_host_transfer(true);
    send_done_instr.set_channel_id(handle.handle());
    *send_done_instr.mutable_shape() = token_shape.ToProto();
    return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
                          {send});
  });
}
2452 
// Enqueues a host-transfer kRecv over {token} followed by a host-transfer
// kRecvDone, both tagged with the channel id of `handle`. Returns the
// kRecvDone op, whose shape is the tuple {shape, token}.
//
// Fails with InvalidArgument when `shape` has no layout, is not an array, or
// when `handle` is not a host-to-device channel (checked in that order).
XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape,
                               const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Shape passed to RecvFromHost must have a layout");
    }

    // TODO(b/111544877): Support tuple shapes.
    if (!shape.IsArray()) {
      return InvalidArgument(
          "RecvFromHost only supports array shapes, shape: %s",
          ShapeUtil::HumanString(shape));
    }

    if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
      return InvalidArgument("RecvFromHost must use a host-to-device channel");
    }

    const Shape context_shape = ShapeUtil::MakeShape(U32, {});
    const Shape token_shape = ShapeUtil::MakeTokenShape();

    // Recv instruction produces a tuple of {receive buffer, U32 context,
    // token}.
    HloInstructionProto recv_instr;
    recv_instr.set_is_host_transfer(true);
    recv_instr.set_channel_id(handle.handle());
    *recv_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, context_shape, token_shape})
            .ToProto();
    TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
                                                   HloOpcode::kRecv, {token}));

    HloInstructionProto recv_done_instr;
    recv_done_instr.set_is_host_transfer(true);
    recv_done_instr.set_channel_id(handle.handle());
    *recv_done_instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, token_shape}).ToProto();
    return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
                          {recv});
  });
}
2493 
// Enqueues a kGetDimensionSize instruction over `operand`, recording
// `dimension` as the instruction's only dimension. The result shape is
// produced by ShapeInference::InferGetDimensionSizeShape.
XlaOp XlaBuilder::GetDimensionSize(const XlaOp& operand, int64 dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const auto& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferGetDimensionSizeShape(
                            operand_shape, dimension));

    HloInstructionProto instr;
    instr.add_dimensions(dimension);
    *instr.mutable_shape() = result_shape.ToProto();
    return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
                          {operand});
  });
}
2506 
IsConstant(const XlaOp & operand) const2507 StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand) const {
2508   TF_RETURN_IF_ERROR(first_error_);
2509 
2510   // Verify that the handle is valid.
2511   TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());
2512 
2513   bool is_constant = true;
2514   absl::flat_hash_set<int64> visited;
2515   IsConstantVisitor(operand.handle(), &visited, &is_constant);
2516   return is_constant;
2517 }
2518 
// Builds a standalone XlaComputation containing `root_op` and every
// instruction reachable from it through operand ids.
//
// Returns InvalidArgument when IsConstant() reports that `root_op` is not
// constant. Each GetDimensionSize instruction encountered is replaced by a
// scalar uint32 constant holding the operand's dimension size, and its operand
// is not followed any further. Embedded computations that are called directly
// by a visited instruction are copied into the resulting module.
StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
    const XlaOp& root_op) {
  TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
  if (!is_constant) {
    // Look the op up only to put its name into the error message.
    auto op_status = LookUpInstruction(root_op);
    string op_string =
        op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
    return InvalidArgument(
        "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
        "  op requested for constant subgraph: %s\n\n"
        "This is an internal error that typically happens when the XLA user "
        "(e.g. TensorFlow) is attempting to determine a value that must be a "
        "compile-time constant (e.g. an array dimension) but it is not capable "
        "of being evaluated at XLA compile time.\n\n"
        "Please file a usability bug with the framework being used (e.g. "
        "TensorFlow).",
        op_string);
  }

  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
                      LookUpInstruction(root_op));

  // The new entry computation gets a fresh id; its result shape is the shape
  // of the requested root instruction.
  HloComputationProto entry;
  SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
                    GetNextId());
  entry.set_root_id(root->id());
  ProgramShapeProto* program_shape = entry.mutable_program_shape();
  *program_shape->mutable_result() = root->shape();

  // We use std::set to keep the instruction ids in ascending order (which is
  // also a valid dependency order). The related ops will be added to the
  // subgraph in the same order.
  std::set<int64> related_ops;
  absl::flat_hash_set<int64> related_calls;  // Related computations.
  // Breadth-first walk over operand ids starting at the root. An id is pushed
  // only the first time it is inserted into related_ops, so each instruction
  // is processed once.
  std::queue<int64> worklist;
  worklist.push(root->id());
  related_ops.insert(root->id());
  while (!worklist.empty()) {
    int64 handle = worklist.front();
    worklist.pop();
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
                        LookUpInstructionByHandle(handle));

    if (instr_proto->opcode() ==
        HloOpcodeString(HloOpcode::kGetDimensionSize)) {
      // At this point, BuildConstantSubGraph should never encounter a
      // GetDimensionSize with a dynamic dimension. IsConstant check would have
      // failed at the beginning of this function.
      //
      // Replace GetDimensionSize with a Constant representing the static bound
      // of the shape.
      int64 dimension = instr_proto->dimensions(0);
      int64 operand_handle = instr_proto->operand_ids(0);
      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                          LookUpInstructionByHandle(operand_handle));

      TF_RET_CHECK(!operand_proto->shape().is_dynamic_dimension(dimension));
      auto constant_dimension_size =
          static_cast<uint32>(operand_proto->shape().dimensions(dimension));

      Literal literal = LiteralUtil::CreateR0(constant_dimension_size);

      HloInstructionProto const_instr;
      *const_instr.mutable_shape() = literal.shape().ToProto();
      *const_instr.mutable_literal() = literal.ToProto();
      *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);

      // The constant reuses the id of the GetDimensionSize it stands in for,
      // so operand ids of instructions that referenced it stay valid.
      const_instr.set_id(handle);
      *const_instr.mutable_name() =
          GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
      *entry.add_instructions() =
          const_instr;  // Add to the result constant graph.
    } else {
      for (int64 id : instr_proto->operand_ids()) {
        if (related_ops.insert(id).second) {
          worklist.push(id);
        }
      }
      for (int64 called_id : instr_proto->called_computation_ids()) {
        related_calls.insert(called_id);
      }
    }
  }

  // Add related ops to the computation.
  for (int64 id : related_ops) {
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
                        LookUpInstructionByHandle(id));

    // GetDimensionSize instructions were already emitted as constants above.
    if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize)) {
      continue;
    }
    auto* instr = entry.add_instructions();

    *instr = *instr_src;
    // Ensures that the instruction names are unique among the graph.
    const string& new_name =
        StrCat(instr->name(), ".", entry.id(), ".", instr->id());
    instr->set_name(new_name);
  }

  // Wrap the entry computation in a module that shares its id and name.
  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = *program_shape;
  // Copy over only the embedded computations recorded in related_calls; the
  // entry computation is appended after them.
  for (auto& e : embedded_) {
    if (related_calls.find(e.second.id()) != related_calls.end()) {
      *module->add_computations() = e.second;
    }
  }
  *module->add_computations() = std::move(entry);

  return std::move(computation);
}
2636 
CreateSubBuilder(const string & computation_name)2637 std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
2638     const string& computation_name) {
2639   auto sub_builder = absl::make_unique<XlaBuilder>(computation_name);
2640   sub_builder->parent_builder_ = this;
2641   sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
2642   return sub_builder;
2643 }
2644 
2645 /* static */ ConvolutionDimensionNumbers
CreateDefaultConvDimensionNumbers(int num_spatial_dims)2646 XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
2647   ConvolutionDimensionNumbers dimension_numbers;
2648   dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
2649   dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
2650   dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
2651   dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
2652   dimension_numbers.set_kernel_output_feature_dimension(
2653       kConvKernelOutputDimension);
2654   dimension_numbers.set_kernel_input_feature_dimension(
2655       kConvKernelInputDimension);
2656   for (int i = 0; i < num_spatial_dims; ++i) {
2657     dimension_numbers.add_input_spatial_dimensions(i + 2);
2658     dimension_numbers.add_kernel_spatial_dimensions(i + 2);
2659     dimension_numbers.add_output_spatial_dimensions(i + 2);
2660   }
2661   return dimension_numbers;
2662 }
2663 
Validate(const ConvolutionDimensionNumbers & dnum)2664 /* static */ Status XlaBuilder::Validate(
2665     const ConvolutionDimensionNumbers& dnum) {
2666   if (dnum.input_spatial_dimensions_size() < 2) {
2667     return FailedPrecondition("input spacial dimension < 2: %d",
2668                               dnum.input_spatial_dimensions_size());
2669   }
2670   if (dnum.kernel_spatial_dimensions_size() < 2) {
2671     return FailedPrecondition("kernel spacial dimension < 2: %d",
2672                               dnum.kernel_spatial_dimensions_size());
2673   }
2674   if (dnum.output_spatial_dimensions_size() < 2) {
2675     return FailedPrecondition("output spacial dimension < 2: %d",
2676                               dnum.output_spatial_dimensions_size());
2677   }
2678 
2679   if (std::set<int64>(
2680           {dnum.input_batch_dimension(), dnum.input_feature_dimension(),
2681            dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)})
2682           .size() != 4) {
2683     return FailedPrecondition(
2684         "dimension numbers for the input are not unique: (%d, %d, %d, "
2685         "%d)",
2686         dnum.input_batch_dimension(), dnum.input_feature_dimension(),
2687         dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1));
2688   }
2689   if (std::set<int64>({dnum.kernel_output_feature_dimension(),
2690                        dnum.kernel_input_feature_dimension(),
2691                        dnum.kernel_spatial_dimensions(0),
2692                        dnum.kernel_spatial_dimensions(1)})
2693           .size() != 4) {
2694     return FailedPrecondition(
2695         "dimension numbers for the weight are not unique: (%d, %d, %d, "
2696         "%d)",
2697         dnum.kernel_output_feature_dimension(),
2698         dnum.kernel_input_feature_dimension(),
2699         dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1));
2700   }
2701   if (std::set<int64>({dnum.output_batch_dimension(),
2702                        dnum.output_feature_dimension(),
2703                        dnum.output_spatial_dimensions(0),
2704                        dnum.output_spatial_dimensions(1)})
2705           .size() != 4) {
2706     return FailedPrecondition(
2707         "dimension numbers for the output are not unique: (%d, %d, %d, "
2708         "%d)",
2709         dnum.output_batch_dimension(), dnum.output_feature_dimension(),
2710         dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1));
2711   }
2712   return Status::OK();
2713 }
2714 
AddInstruction(HloInstructionProto && instr,HloOpcode opcode,absl::Span<const XlaOp> operands)2715 StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
2716                                            HloOpcode opcode,
2717                                            absl::Span<const XlaOp> operands) {
2718   TF_RETURN_IF_ERROR(first_error_);
2719 
2720   const int64 handle = GetNextId();
2721   instr.set_id(handle);
2722   instr.set_opcode(HloOpcodeString(opcode));
2723   if (instr.name().empty()) {
2724     instr.set_name(instr.opcode());
2725   }
2726   for (const auto& operand : operands) {
2727     if (operand.builder_ == nullptr) {
2728       return InvalidArgument("invalid XlaOp with handle %d", operand.handle());
2729     }
2730     if (operand.builder_ != this) {
2731       return InvalidArgument("Do not add XlaOp from builder %s to builder %s",
2732                              operand.builder_->name(), this->name());
2733     }
2734     instr.add_operand_ids(operand.handle());
2735   }
2736 
2737   *instr.mutable_metadata() = metadata_;
2738   if (sharding_) {
2739     *instr.mutable_sharding() = *sharding_;
2740   }
2741 
2742   handle_to_index_[handle] = instructions_.size();
2743   instructions_.push_back(std::move(instr));
2744 
2745   XlaOp op(handle, this);
2746   return op;
2747 }
2748 
AddCalledComputation(const XlaComputation & computation,HloInstructionProto * instr)2749 void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
2750                                       HloInstructionProto* instr) {
2751   absl::flat_hash_map<int64, int64> remapped_ids;
2752   std::vector<HloComputationProto> imported_computations;
2753   imported_computations.reserve(computation.proto().computations_size());
2754   // Before we import the computations by remapping IDs, and capturing the
2755   // old->new mappings in remapped_ids.
2756   for (const HloComputationProto& e : computation.proto().computations()) {
2757     HloComputationProto new_computation(e);
2758     int64 computation_id = GetNextId();
2759     remapped_ids[new_computation.id()] = computation_id;
2760     SetProtoIdAndName(&new_computation,
2761                       GetBaseName(new_computation.name(), kNameSeparator),
2762                       kNameSeparator, computation_id);
2763     for (auto& instruction : *new_computation.mutable_instructions()) {
2764       int64 instruction_id = GetNextId();
2765       remapped_ids[instruction.id()] = instruction_id;
2766       SetProtoIdAndName(&instruction,
2767                         GetBaseName(instruction.name(), kNameSeparator),
2768                         kNameSeparator, instruction_id);
2769     }
2770     new_computation.set_root_id(remapped_ids.at(new_computation.root_id()));
2771 
2772     imported_computations.push_back(std::move(new_computation));
2773   }
2774   // Once we have imported all the computations, and captured all the ID
2775   // mappings, we go back and fixup the IDs in the imported computations.
2776   instr->add_called_computation_ids(
2777       remapped_ids.at(computation.proto().entry_computation_id()));
2778   for (auto& imported_computation : imported_computations) {
2779     for (auto& instruction : *imported_computation.mutable_instructions()) {
2780       for (auto& operand_id : *instruction.mutable_operand_ids()) {
2781         operand_id = remapped_ids.at(operand_id);
2782       }
2783       for (auto& control_predecessor_id :
2784            *instruction.mutable_control_predecessor_ids()) {
2785         control_predecessor_id = remapped_ids.at(control_predecessor_id);
2786       }
2787       for (auto& called_computation_id :
2788            *instruction.mutable_called_computation_ids()) {
2789         called_computation_id = remapped_ids.at(called_computation_id);
2790       }
2791     }
2792 
2793     int64 computation_id = imported_computation.id();
2794     embedded_.insert({computation_id, std::move(imported_computation)});
2795   }
2796 }
2797 
LookUpInstruction(const XlaOp & op) const2798 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
2799     const XlaOp& op) const {
2800   TF_RETURN_IF_ERROR(first_error_);
2801 
2802   if (op.builder_ == nullptr) {
2803     return InvalidArgument(
2804         "invalid XlaOp with handle %d; the builder of this op is freed",
2805         op.handle());
2806   }
2807   if (op.builder_ != this) {
2808     return InvalidArgument(
2809         "XlaOp with handle %d is built by builder '%s', but is trying to use "
2810         "it in builder '%s'",
2811         op.handle(), op.builder_->name(), this->name());
2812   }
2813 
2814   return LookUpInstructionByHandle(op.handle());
2815 }
2816 
LookUpInstructionByHandle(int64 handle) const2817 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
2818     int64 handle) const {
2819   auto it = handle_to_index_.find(handle);
2820   if (it == handle_to_index_.end()) {
2821     return InvalidArgument("No XlaOp with handle %d", handle);
2822   }
2823   return &instructions_[it->second];
2824 }
2825 
2826 // Enqueues a "retrieve parameter value" instruction for a parameter that was
2827 // passed to the computation.
Parameter(XlaBuilder * builder,int64 parameter_number,const Shape & shape,const string & name)2828 XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
2829                 const string& name) {
2830   return builder->Parameter(parameter_number, shape, name);
2831 }
2832 
2833 // Enqueues a constant with the value of the given literal onto the
2834 // computation.
ConstantLiteral(XlaBuilder * builder,const LiteralSlice & literal)2835 XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
2836   return builder->ConstantLiteral(literal);
2837 }
2838 
// Free-function form of XlaBuilder::Broadcast; forwards to the builder that
// `operand` reports through builder().
XlaOp Broadcast(const XlaOp& operand, absl::Span<const int64> broadcast_sizes) {
  XlaBuilder* const builder = operand.builder();
  return builder->Broadcast(operand, broadcast_sizes);
}
2842 
// Free-function form of XlaBuilder::BroadcastInDim; forwards all arguments to
// the builder that `operand` reports through builder().
XlaOp BroadcastInDim(const XlaOp& operand,
                     const absl::Span<const int64> out_dim_size,
                     const absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const builder = operand.builder();
  return builder->BroadcastInDim(operand, out_dim_size, broadcast_dimensions);
}
2849 
// Free-function form of XlaBuilder::Pad; the builder is taken from `operand`
// (not from `padding_value`).
XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
          const PaddingConfig& padding_config) {
  XlaBuilder* const builder = operand.builder();
  return builder->Pad(operand, padding_value, padding_config);
}
2854 
// Free-function form of the XlaBuilder::Reshape overload that takes both
// `dimensions` and `new_sizes`; forwards to the builder of `operand`.
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
              absl::Span<const int64> new_sizes) {
  XlaBuilder* const builder = operand.builder();
  return builder->Reshape(operand, dimensions, new_sizes);
}
2859 
// Free-function form of the XlaBuilder::Reshape overload that takes only
// `new_sizes`; forwards to the builder of `operand`.
XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes) {
  XlaBuilder* const builder = operand.builder();
  return builder->Reshape(operand, new_sizes);
}
2863 
// Free-function form of XlaBuilder::Collapse; forwards to the builder of
// `operand`.
XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions) {
  XlaBuilder* const builder = operand.builder();
  return builder->Collapse(operand, dimensions);
}
2867 
// Free-function form of XlaBuilder::Slice; forwards the start/limit indices
// and strides unchanged to the builder of `operand`.
XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
            absl::Span<const int64> limit_indices,
            absl::Span<const int64> strides) {
  XlaBuilder* const builder = operand.builder();
  return builder->Slice(operand, start_indices, limit_indices, strides);
}
2874 
// Free-function form of XlaBuilder::SliceInDim; forwards to the builder of
// `operand`.
XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
                 int64 stride, int64 dimno) {
  XlaBuilder* const builder = operand.builder();
  return builder->SliceInDim(operand, start_index, limit_index, stride, dimno);
}
2880 
// Free-function form of the XlaBuilder::DynamicSlice overload that takes the
// start indices as a single XlaOp; forwards to the builder of `operand`.
XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                   absl::Span<const int64> slice_sizes) {
  XlaBuilder* const builder = operand.builder();
  return builder->DynamicSlice(operand, start_indices, slice_sizes);
}
// Free-function form of the XlaBuilder::DynamicSlice overload that takes the
// start indices as a span of XlaOps; forwards to the builder of `operand`.
XlaOp DynamicSlice(const XlaOp& operand, absl::Span<const XlaOp> start_indices,
                   absl::Span<const int64> slice_sizes) {
  XlaBuilder* const builder = operand.builder();
  return builder->DynamicSlice(operand, start_indices, slice_sizes);
}
2889 
// Free-function form of the XlaBuilder::DynamicUpdateSlice overload that
// takes the start indices as a single XlaOp; the builder comes from `operand`.
XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                         const XlaOp& start_indices) {
  XlaBuilder* const builder = operand.builder();
  return builder->DynamicUpdateSlice(operand, update, start_indices);
}
2894 
// Free-function form of the XlaBuilder::DynamicUpdateSlice overload that
// takes the start indices as a span of XlaOps; the builder comes from
// `operand`.
XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                         absl::Span<const XlaOp> start_indices) {
  XlaBuilder* const builder = operand.builder();
  return builder->DynamicUpdateSlice(operand, update, start_indices);
}
2899 
ConcatInDim(XlaBuilder * builder,absl::Span<const XlaOp> operands,int64 dimension)2900 XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
2901                   int64 dimension) {
2902   return builder->ConcatInDim(operands, dimension);
2903 }
2904 
Trace(const string & tag,const XlaOp & operand)2905 void Trace(const string& tag, const XlaOp& operand) {
2906   return operand.builder()->Trace(tag, operand);
2907 }
2908 
// Free-function form of XlaBuilder::Select; the builder is taken from `pred`.
XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) {
  XlaBuilder* const builder = pred.builder();
  return builder->Select(pred, on_true, on_false);
}
2912 
Tuple(XlaBuilder * builder,absl::Span<const XlaOp> elements)2913 XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
2914   return builder->Tuple(elements);
2915 }
2916 
// Free-function form of XlaBuilder::GetTupleElement; forwards to the builder
// of `tuple_data`.
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) {
  XlaBuilder* const builder = tuple_data.builder();
  return builder->GetTupleElement(tuple_data, index);
}
2920 
// Shorthand for Compare() with ComparisonDirection::kEq.
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kEq;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
2925 
// Shorthand for Compare() with ComparisonDirection::kNe.
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kNe;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
2930 
// Shorthand for Compare() with ComparisonDirection::kGe.
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kGe;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
2935 
// Shorthand for Compare() with ComparisonDirection::kGt.
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kGt;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
2940 
// Shorthand for Compare() with ComparisonDirection::kLe.
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kLe;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
2945 
// Shorthand for Compare() with ComparisonDirection::kLt.
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kLt;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
2950 
// Enqueues a kCompare binary op with the given `direction` on the builder of
// `lhs`, via XlaBuilder::BinaryOp.
XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction) {
  XlaBuilder* const builder = lhs.builder();
  return builder->BinaryOp(HloOpcode::kCompare, lhs, rhs, broadcast_dimensions,
                           direction);
}
2957 
// Free-function form of XlaBuilder::Dot; the builder is taken from `lhs` and
// `precision_config` is passed through as given.
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
          const PrecisionConfig* precision_config) {
  XlaBuilder* const builder = lhs.builder();
  return builder->Dot(lhs, rhs, precision_config);
}
2962 
// Free-function form of XlaBuilder::DotGeneral; the builder is taken from
// `lhs` and the remaining arguments are forwarded unchanged.
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                 const DotDimensionNumbers& dimension_numbers,
                 const PrecisionConfig* precision_config) {
  XlaBuilder* const builder = lhs.builder();
  return builder->DotGeneral(lhs, rhs, dimension_numbers, precision_config);
}
2969 
// Free-function form of XlaBuilder::Conv; the builder is taken from `lhs` and
// all arguments are forwarded unchanged.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
           absl::Span<const int64> window_strides, Padding padding,
           int64 feature_group_count, int64 batch_group_count,
           const PrecisionConfig* precision_config) {
  XlaBuilder* const builder = lhs.builder();
  return builder->Conv(lhs, rhs, window_strides, padding, feature_group_count,
                       batch_group_count, precision_config);
}
2978 
// Free-function form of XlaBuilder::ConvWithGeneralPadding; the builder is
// taken from `lhs` and all arguments are forwarded unchanged.
XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
                             absl::Span<const int64> window_strides,
                             absl::Span<const std::pair<int64, int64>> padding,
                             int64 feature_group_count, int64 batch_group_count,
                             const PrecisionConfig* precision_config) {
  XlaBuilder* const builder = lhs.builder();
  return builder->ConvWithGeneralPadding(lhs, rhs, window_strides, padding,
                                         feature_group_count,
                                         batch_group_count, precision_config);
}
2988 
// Free-function form of XlaBuilder::ConvWithGeneralDimensions; the builder is
// taken from `lhs` and all arguments are forwarded unchanged.
XlaOp ConvWithGeneralDimensions(
    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config) {
  XlaBuilder* const builder = lhs.builder();
  return builder->ConvWithGeneralDimensions(
      lhs, rhs, window_strides, padding, dimension_numbers,
      feature_group_count, batch_group_count, precision_config);
}
2998 
// Free-function form of XlaBuilder::ConvGeneral; the builder is taken from
// `lhs` and all arguments are forwarded unchanged.
XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
                  absl::Span<const int64> window_strides,
                  absl::Span<const std::pair<int64, int64>> padding,
                  const ConvolutionDimensionNumbers& dimension_numbers,
                  int64 feature_group_count, int64 batch_group_count,
                  const PrecisionConfig* precision_config) {
  XlaBuilder* const builder = lhs.builder();
  return builder->ConvGeneral(lhs, rhs, window_strides, padding,
                              dimension_numbers, feature_group_count,
                              batch_group_count, precision_config);
}
3009 
// Free-function form of XlaBuilder::ConvGeneralDilated; the builder is taken
// from `lhs` and all arguments are forwarded unchanged.
XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
                         absl::Span<const int64> window_strides,
                         absl::Span<const std::pair<int64, int64>> padding,
                         absl::Span<const int64> lhs_dilation,
                         absl::Span<const int64> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64 feature_group_count, int64 batch_group_count,
                         const PrecisionConfig* precision_config) {
  XlaBuilder* const builder = lhs.builder();
  return builder->ConvGeneralDilated(lhs, rhs, window_strides, padding,
                                     lhs_dilation, rhs_dilation,
                                     dimension_numbers, feature_group_count,
                                     batch_group_count, precision_config);
}
3023 
// Free-function form of XlaBuilder::Fft; forwards to the builder of
// `operand`.
XlaOp Fft(const XlaOp& operand, FftType fft_type,
          absl::Span<const int64> fft_length) {
  XlaBuilder* const builder = operand.builder();
  return builder->Fft(operand, fft_type, fft_length);
}
3028 
TriangularSolve(XlaOp a,XlaOp b,bool left_side,bool lower,bool unit_diagonal,TriangularSolveOptions::Transpose transpose_a)3029 XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
3030                       bool unit_diagonal,
3031                       TriangularSolveOptions::Transpose transpose_a) {
3032   XlaBuilder* builder = a.builder();
3033   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3034     HloInstructionProto instr;
3035     TF_ASSIGN_OR_RETURN(const Shape& a_shape, builder->GetShape(a));
3036     TF_ASSIGN_OR_RETURN(const Shape& b_shape, builder->GetShape(b));
3037     xla::TriangularSolveOptions& options =
3038         *instr.mutable_triangular_solve_options();
3039     options.set_left_side(left_side);
3040     options.set_lower(lower);
3041     options.set_unit_diagonal(unit_diagonal);
3042     options.set_transpose_a(transpose_a);
3043     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape(
3044                                          a_shape, b_shape, options));
3045     *instr.mutable_shape() = shape.ToProto();
3046 
3047     return builder->AddInstruction(std::move(instr),
3048                                    HloOpcode::kTriangularSolve, {a, b});
3049   });
3050 }
3051 
Cholesky(XlaOp a,bool lower)3052 XlaOp Cholesky(XlaOp a, bool lower) {
3053   XlaBuilder* builder = a.builder();
3054   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
3055     HloInstructionProto instr;
3056     TF_ASSIGN_OR_RETURN(const Shape& a_shape, builder->GetShape(a));
3057     xla::CholeskyOptions& options = *instr.mutable_cholesky_options();
3058     options.set_lower(lower);
3059     TF_ASSIGN_OR_RETURN(Shape shape,
3060                         ShapeInference::InferCholeskyShape(a_shape));
3061     *instr.mutable_shape() = shape.ToProto();
3062 
3063     return builder->AddInstruction(std::move(instr), HloOpcode::kCholesky, {a});
3064   });
3065 }
3066 
Infeed(XlaBuilder * builder,const Shape & shape,const string & config)3067 XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) {
3068   return builder->Infeed(shape, config);
3069 }
3070 
Outfeed(const XlaOp & operand,const Shape & shape_with_layout,const string & outfeed_config)3071 void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
3072              const string& outfeed_config) {
3073   return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config);
3074 }
3075 
Call(XlaBuilder * builder,const XlaComputation & computation,absl::Span<const XlaOp> operands)3076 XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
3077            absl::Span<const XlaOp> operands) {
3078   return builder->Call(computation, operands);
3079 }
3080 
CustomCall(XlaBuilder * builder,const string & call_target_name,absl::Span<const XlaOp> operands,const Shape & shape,const string & opaque)3081 XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
3082                  absl::Span<const XlaOp> operands, const Shape& shape,
3083                  const string& opaque) {
3084   return builder->CustomCall(call_target_name, operands, shape, opaque,
3085                              /*operand_shapes_with_layout=*/absl::nullopt);
3086 }
3087 
CustomCallWithLayout(XlaBuilder * builder,const string & call_target_name,absl::Span<const XlaOp> operands,const Shape & shape,absl::Span<const Shape> operand_shapes_with_layout,const string & opaque)3088 XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
3089                            absl::Span<const XlaOp> operands, const Shape& shape,
3090                            absl::Span<const Shape> operand_shapes_with_layout,
3091                            const string& opaque) {
3092   return builder->CustomCall(call_target_name, operands, shape, opaque,
3093                              operand_shapes_with_layout);
3094 }
3095 
// Enqueues a kComplex binary op on the builder of `lhs`, via
// XlaBuilder::BinaryOp.
XlaOp Complex(const XlaOp& lhs, const XlaOp& rhs,
              absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const builder = lhs.builder();
  return builder->BinaryOp(HloOpcode::kComplex, lhs, rhs,
                           broadcast_dimensions);
}
3101 
// Builds the conjugate of `operand` out of existing ops:
// Complex(Real(operand), Neg(Imag(operand))).
// NOTE(review): the two arguments of Complex() are evaluated in an order the
// C++ standard leaves unspecified, so the relative order in which the Real and
// the Imag/Neg ops are enqueued may differ between compilers.
XlaOp Conj(const XlaOp& operand) {
  return Complex(Real(operand), Neg(Imag(operand)));
}
3105 
// Enqueues a kAdd binary op on the builder of `lhs`, via
// XlaBuilder::BinaryOp.
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const builder = lhs.builder();
  return builder->BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions);
}
3111 
// Enqueues a kSubtract binary op on the builder of `lhs`, via
// XlaBuilder::BinaryOp.
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const builder = lhs.builder();
  return builder->BinaryOp(HloOpcode::kSubtract, lhs, rhs,
                           broadcast_dimensions);
}
3117 
// Enqueues a kMultiply binary op on the builder of `lhs`, via
// XlaBuilder::BinaryOp.
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const builder = lhs.builder();
  return builder->BinaryOp(HloOpcode::kMultiply, lhs, rhs,
                           broadcast_dimensions);
}
3123 
// Forwards to XlaBuilder::BinaryOp with HloOpcode::kDivide, using the builder
// attached to `lhs`.
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const owner = lhs.builder();
  return owner->BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions);
}
3129 
// Forwards to XlaBuilder::BinaryOp with HloOpcode::kRemainder, using the
// builder attached to `lhs`.
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const owner = lhs.builder();
  return owner->BinaryOp(HloOpcode::kRemainder, lhs, rhs,
                         broadcast_dimensions);
}
3135 
// Forwards to XlaBuilder::BinaryOp with HloOpcode::kMaximum, using the builder
// attached to `lhs`.
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const owner = lhs.builder();
  return owner->BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions);
}
3141 
// Forwards to XlaBuilder::BinaryOp with HloOpcode::kMinimum, using the builder
// attached to `lhs`.
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const owner = lhs.builder();
  return owner->BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions);
}
3147 
// Forwards to XlaBuilder::BinaryOp with HloOpcode::kAnd, using the builder
// attached to `lhs`.
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const owner = lhs.builder();
  return owner->BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions);
}
3153 
// Forwards to XlaBuilder::BinaryOp with HloOpcode::kOr, using the builder
// attached to `lhs`.
XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
         absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const owner = lhs.builder();
  return owner->BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions);
}
3159 
// Forwards to XlaBuilder::BinaryOp with HloOpcode::kXor, using the builder
// attached to `lhs`.
XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const owner = lhs.builder();
  return owner->BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions);
}
3165 
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kNot on the operand's builder.
XlaOp Not(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kNot, operand);
}
3169 
// Forwards to XlaBuilder::BinaryOp with HloOpcode::kShiftLeft, using the
// builder attached to `lhs`.
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
                absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const owner = lhs.builder();
  return owner->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
                         broadcast_dimensions);
}
3175 
// Forwards to XlaBuilder::BinaryOp with HloOpcode::kShiftRightArithmetic, using
// the builder attached to `lhs`.
XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
                           absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const owner = lhs.builder();
  return owner->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
                         broadcast_dimensions);
}
3181 
// Forwards to XlaBuilder::BinaryOp with HloOpcode::kShiftRightLogical, using
// the builder attached to `lhs`.
XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
                        absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const owner = lhs.builder();
  return owner->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
                         broadcast_dimensions);
}
3187 
// Single-operand form: forwards to XlaBuilder::Reduce on the builder attached
// to `operand`.
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce) {
  XlaBuilder* const owner = operand.builder();
  return owner->Reduce(operand, init_value, computation, dimensions_to_reduce);
}
3194 
3195 // Reduces several arrays simultaneously among the provided dimensions, given
3196 // "computation" as a reduction operator.
Reduce(XlaBuilder * builder,absl::Span<const XlaOp> operands,absl::Span<const XlaOp> init_values,const XlaComputation & computation,absl::Span<const int64> dimensions_to_reduce)3197 XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
3198              absl::Span<const XlaOp> init_values,
3199              const XlaComputation& computation,
3200              absl::Span<const int64> dimensions_to_reduce) {
3201   return builder->Reduce(operands, init_values, computation,
3202                          dimensions_to_reduce);
3203 }
3204 
// Forwards to XlaBuilder::ReduceAll on the builder attached to `operand`.
XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
                const XlaComputation& computation) {
  XlaBuilder* const owner = operand.builder();
  return owner->ReduceAll(operand, init_value, computation);
}
3209 
// Forwards to XlaBuilder::ReduceWindow (the `Padding` enum variant) on the
// builder attached to `operand`.
XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding) {
  XlaBuilder* const owner = operand.builder();
  return owner->ReduceWindow(operand, init_value, computation,
                             window_dimensions, window_strides, padding);
}
3218 
// Forwards to XlaBuilder::ReduceWindowWithGeneralPadding on the builder
// attached to `operand`; padding is given as explicit (int64, int64) pairs.
XlaOp ReduceWindowWithGeneralPadding(
    const XlaOp& operand, const XlaOp& init_value,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  XlaBuilder* const owner = operand.builder();
  return owner->ReduceWindowWithGeneralPadding(
      operand, init_value, computation, window_dimensions, window_strides,
      base_dilations, window_dilations, padding);
}
3231 
// Forwards to the two-argument XlaBuilder::CrossReplicaSum on the builder
// attached to `operand`.
XlaOp CrossReplicaSum(const XlaOp& operand,
                      absl::Span<const ReplicaGroup> replica_groups) {
  XlaBuilder* const owner = operand.builder();
  return owner->CrossReplicaSum(operand, replica_groups);
}
3236 
// Overload taking a caller-supplied computation and an optional channel id;
// forwards to the matching XlaBuilder::CrossReplicaSum on the operand's
// builder.
XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
                      absl::Span<const ReplicaGroup> replica_groups,
                      const absl::optional<ChannelHandle>& channel_id) {
  XlaBuilder* const owner = operand.builder();
  return owner->CrossReplicaSum(operand, computation, replica_groups,
                                channel_id);
}
3243 
// Forwards to XlaBuilder::AllToAll on the builder attached to `operand`.
XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
               int64 concat_dimension, int64 split_count,
               const std::vector<ReplicaGroup>& replica_groups) {
  XlaBuilder* const owner = operand.builder();
  return owner->AllToAll(operand, split_dimension, concat_dimension,
                         split_count, replica_groups);
}
3250 
// Forwards to XlaBuilder::CollectivePermute on the builder attached to
// `operand`.
XlaOp CollectivePermute(
    const XlaOp& operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  XlaBuilder* const owner = operand.builder();
  return owner->CollectivePermute(operand, source_target_pairs);
}
3256 
ReplicaId(XlaBuilder * builder)3257 XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); }
3258 
// Forwards to XlaBuilder::SelectAndScatter (the `Padding` enum variant) on the
// builder attached to `operand`.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
                       absl::Span<const int64> window_dimensions,
                       absl::Span<const int64> window_strides, Padding padding,
                       const XlaOp& source, const XlaOp& init_value,
                       const XlaComputation& scatter) {
  XlaBuilder* const owner = operand.builder();
  return owner->SelectAndScatter(operand, select, window_dimensions,
                                 window_strides, padding, source, init_value,
                                 scatter);
}
3268 
// Forwards to XlaBuilder::SelectAndScatterWithGeneralPadding on the builder
// attached to `operand`; padding is given as explicit (int64, int64) pairs.
XlaOp SelectAndScatterWithGeneralPadding(
    const XlaOp& operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
    const XlaOp& init_value, const XlaComputation& scatter) {
  XlaBuilder* const owner = operand.builder();
  return owner->SelectAndScatterWithGeneralPadding(
      operand, select, window_dimensions, window_strides, padding, source,
      init_value, scatter);
}
3279 
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kAbs on the operand's builder.
XlaOp Abs(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kAbs, operand);
}
3283 
// Forwards to XlaBuilder::BinaryOp with HloOpcode::kAtan2, using the builder
// attached to `lhs`.
XlaOp Atan2(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const owner = lhs.builder();
  return owner->BinaryOp(HloOpcode::kAtan2, lhs, rhs, broadcast_dimensions);
}
3289 
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kExp on the operand's builder.
XlaOp Exp(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kExp, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kExpm1 on the operand's
// builder.
XlaOp Expm1(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kExpm1, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kFloor on the operand's
// builder.
XlaOp Floor(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kFloor, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kCeil on the operand's
// builder.
XlaOp Ceil(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kCeil, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kRoundNearestAfz on the
// operand's builder.
XlaOp Round(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kRoundNearestAfz, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kLog on the operand's builder.
XlaOp Log(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kLog, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kLog1p on the operand's
// builder.
XlaOp Log1p(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kLog1p, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kSign on the operand's
// builder.
XlaOp Sign(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kSign, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kClz on the operand's builder.
XlaOp Clz(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kClz, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kCos on the operand's builder.
XlaOp Cos(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kCos, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kSin on the operand's builder.
XlaOp Sin(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kSin, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kTanh on the operand's
// builder.
XlaOp Tanh(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kTanh, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kReal on the operand's
// builder.
XlaOp Real(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kReal, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kImag on the operand's
// builder.
XlaOp Imag(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kImag, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kSqrt on the operand's
// builder.
XlaOp Sqrt(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kSqrt, operand);
}
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kRsqrt on the operand's
// builder.
XlaOp Rsqrt(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kRsqrt, operand);
}
3338 
// Forwards to XlaBuilder::BinaryOp with HloOpcode::kPower, using the builder
// attached to `lhs`.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
          absl::Span<const int64> broadcast_dimensions) {
  XlaBuilder* const owner = lhs.builder();
  return owner->BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions);
}
3344 
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kIsFinite on the operand's
// builder.
XlaOp IsFinite(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kIsFinite, operand);
}
3348 
// Forwards to XlaBuilder::ConvertElementType on the builder attached to
// `operand`.
XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) {
  XlaBuilder* const owner = operand.builder();
  return owner->ConvertElementType(operand, new_element_type);
}
3352 
// Forwards to XlaBuilder::BitcastConvertType on the builder attached to
// `operand`.
XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) {
  XlaBuilder* const owner = operand.builder();
  return owner->BitcastConvertType(operand, new_element_type);
}
3356 
// Forwards to XlaBuilder::UnaryOp with HloOpcode::kNegate on the operand's
// builder.
XlaOp Neg(const XlaOp& operand) {
  XlaBuilder* const owner = operand.builder();
  return owner->UnaryOp(HloOpcode::kNegate, operand);
}
3360 
// Forwards to XlaBuilder::Transpose on the builder attached to `operand`.
XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation) {
  XlaBuilder* const owner = operand.builder();
  return owner->Transpose(operand, permutation);
}
3364 
// Forwards to XlaBuilder::Rev on the builder attached to `operand`.
XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions) {
  XlaBuilder* const owner = operand.builder();
  return owner->Rev(operand, dimensions);
}
3368 
// Keys/values form: forwards to XlaBuilder::Sort on the builder attached to
// `keys`.
XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values, int64 dimension) {
  XlaBuilder* const owner = keys.builder();
  return owner->Sort(keys, values, dimension);
}
3372 
Sort(absl::Span<const XlaOp> operands,const XlaComputation & comparator,int64 dimension,bool is_stable)3373 XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
3374            int64 dimension, bool is_stable) {
3375   return operands[0].builder()->Sort(operands, comparator, dimension,
3376                                      is_stable);
3377 }
3378 
// Forwards to XlaBuilder::Clamp; the builder is the one attached to `min`.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) {
  XlaBuilder* const owner = min.builder();
  return owner->Clamp(min, operand, max);
}
3382 
Map(XlaBuilder * builder,absl::Span<const XlaOp> operands,const XlaComputation & computation,absl::Span<const int64> dimensions,absl::Span<const XlaOp> static_operands)3383 XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
3384           const XlaComputation& computation, absl::Span<const int64> dimensions,
3385           absl::Span<const XlaOp> static_operands) {
3386   return builder->Map(operands, computation, dimensions, static_operands);
3387 }
3388 
// Forwards to XlaBuilder::RngNormal; the builder is the one attached to `mu`.
XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape) {
  XlaBuilder* const owner = mu.builder();
  return owner->RngNormal(mu, sigma, shape);
}
3392 
// Forwards to XlaBuilder::RngUniform; the builder is the one attached to `a`.
XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape) {
  XlaBuilder* const owner = a.builder();
  return owner->RngUniform(a, b, shape);
}
3396 
// Forwards to XlaBuilder::While; the builder is the one attached to `init`.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp& init) {
  XlaBuilder* const owner = init.builder();
  return owner->While(condition, body, init);
}
3401 
// Two-branch form: forwards to XlaBuilder::Conditional on the builder attached
// to `predicate`.
XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp& false_operand,
                  const XlaComputation& false_computation) {
  XlaBuilder* const owner = predicate.builder();
  return owner->Conditional(predicate, true_operand, true_computation,
                            false_operand, false_computation);
}
3410 
// Indexed form: forwards to XlaBuilder::Conditional on the builder attached to
// `branch_index`.
XlaOp Conditional(const XlaOp& branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands) {
  XlaBuilder* const owner = branch_index.builder();
  return owner->Conditional(branch_index, branch_computations,
                            branch_operands);
}
3417 
// Forwards to XlaBuilder::ReducePrecision on the builder attached to `operand`.
XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
                      const int mantissa_bits) {
  XlaBuilder* const owner = operand.builder();
  return owner->ReducePrecision(operand, exponent_bits, mantissa_bits);
}
3423 
// Forwards to XlaBuilder::Gather on the builder attached to `input`.
XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes) {
  XlaBuilder* const owner = input.builder();
  return owner->Gather(input, start_indices, dimension_numbers, slice_sizes);
}
3430 
// Forwards to XlaBuilder::Scatter on the builder attached to `input`.
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
              const XlaOp& updates, const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers) {
  XlaBuilder* const owner = input.builder();
  return owner->Scatter(input, scatter_indices, updates, update_computation,
                        dimension_numbers);
}
3437 
Send(const XlaOp & operand,const ChannelHandle & handle)3438 void Send(const XlaOp& operand, const ChannelHandle& handle) {
3439   return operand.builder()->Send(operand, handle);
3440 }
3441 
Recv(XlaBuilder * builder,const Shape & shape,const ChannelHandle & handle)3442 XlaOp Recv(XlaBuilder* builder, const Shape& shape,
3443            const ChannelHandle& handle) {
3444   return builder->Recv(shape, handle);
3445 }
3446 
// Forwards to XlaBuilder::SendWithToken on the builder attached to `operand`.
XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
                    const ChannelHandle& handle) {
  XlaBuilder* const owner = operand.builder();
  return owner->SendWithToken(operand, token, handle);
}
3451 
// Forwards to XlaBuilder::RecvWithToken on the builder attached to `token`.
XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
                    const ChannelHandle& handle) {
  XlaBuilder* const owner = token.builder();
  return owner->RecvWithToken(token, shape, handle);
}
3456 
// Forwards to XlaBuilder::SendToHost on the builder attached to `operand`.
XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
                 const Shape& shape_with_layout, const ChannelHandle& handle) {
  XlaBuilder* const owner = operand.builder();
  return owner->SendToHost(operand, token, shape_with_layout, handle);
}
3462 
// Forwards to XlaBuilder::RecvFromHost on the builder attached to `token`.
XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
                   const ChannelHandle& handle) {
  XlaBuilder* const owner = token.builder();
  return owner->RecvFromHost(token, shape, handle);
}
3467 
// Forwards to XlaBuilder::InfeedWithToken on the builder attached to `token`.
XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
                      const string& config) {
  XlaBuilder* const owner = token.builder();
  return owner->InfeedWithToken(token, shape, config);
}
3472 
// Forwards to XlaBuilder::OutfeedWithToken on the builder attached to
// `operand`.
XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
                       const Shape& shape_with_layout,
                       const string& outfeed_config) {
  XlaBuilder* const owner = operand.builder();
  return owner->OutfeedWithToken(operand, token, shape_with_layout,
                                 outfeed_config);
}
3479 
CreateToken(XlaBuilder * builder)3480 XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }
3481 
AfterAll(XlaBuilder * builder,absl::Span<const XlaOp> tokens)3482 XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
3483   return builder->AfterAll(tokens);
3484 }
3485 
// Forwards to XlaBuilder::BatchNormTraining on the builder attached to
// `operand`.
XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
                        const XlaOp& offset, float epsilon,
                        int64 feature_index) {
  XlaBuilder* const owner = operand.builder();
  return owner->BatchNormTraining(operand, scale, offset, epsilon,
                                  feature_index);
}
3492 
// Forwards to XlaBuilder::BatchNormInference on the builder attached to
// `operand`.
XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
                         const XlaOp& offset, const XlaOp& mean,
                         const XlaOp& variance, float epsilon,
                         int64 feature_index) {
  XlaBuilder* const owner = operand.builder();
  return owner->BatchNormInference(operand, scale, offset, mean, variance,
                                   epsilon, feature_index);
}
3500 
// Forwards to XlaBuilder::BatchNormGrad on the builder attached to `operand`.
XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
                    const XlaOp& batch_mean, const XlaOp& batch_var,
                    const XlaOp& grad_output, float epsilon,
                    int64 feature_index) {
  XlaBuilder* const owner = operand.builder();
  return owner->BatchNormGrad(operand, scale, batch_mean, batch_var,
                              grad_output, epsilon, feature_index);
}
3508 
Iota(XlaBuilder * builder,PrimitiveType type,int64 size)3509 XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
3510   return builder->Iota(type, size);
3511 }
3512 
Iota(XlaBuilder * builder,const Shape & shape,int64 iota_dimension)3513 XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) {
3514   return builder->Iota(shape, iota_dimension);
3515 }
3516 
// Forwards to XlaBuilder::GetDimensionSize on the builder attached to
// `operand`.
XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension) {
  XlaBuilder* const owner = operand.builder();
  return owner->GetDimensionSize(operand, dimension);
}
3520 
3521 }  // namespace xla
3522