/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_instruction.h"

#include <algorithm>
#include <ostream>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using absl::CEscape;
using absl::StrAppend;
using absl::StrCat;
using absl::StrJoin;

/* static */
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
    const HloInstructionProto& proto,
    const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
    const absl::flat_hash_map<int64, HloComputation*>& computation_map,
    bool prohibit_empty_literal) {
  TF_RET_CHECK(!proto.opcode().empty());
  HloOpcode opcode;
  auto opcode_or = StringToHloOpcode(proto.opcode());
  absl::optional<ComparisonDirection> comparison_direction;
  if (opcode_or.ok()) {
    opcode = opcode_or.ConsumeValueOrDie();
  } else {
    // Unknown opcode. Try auto-upgrading deprecated "less-than",
    // "greater-than", etc opcodes, which are now rolled into the kCompare
    // opcode.
    if (proto.opcode() == "equal-to") {
      comparison_direction = ComparisonDirection::kEq;
    } else if (proto.opcode() == "not-equal-to") {
      comparison_direction = ComparisonDirection::kNe;
    } else if (proto.opcode() == "greater-than-or-equal-to") {
      comparison_direction = ComparisonDirection::kGe;
    } else if (proto.opcode() == "greater-than") {
      comparison_direction = ComparisonDirection::kGt;
    } else if (proto.opcode() == "less-than-or-equal-to") {
      comparison_direction = ComparisonDirection::kLe;
    } else if (proto.opcode() == "less-than") {
      comparison_direction = ComparisonDirection::kLt;
    }
    if (comparison_direction) {
      opcode = HloOpcode::kCompare;
    } else {
      return InvalidArgument("Unknown opcode: %s", proto.opcode());
    }
  }

  TF_RET_CHECK(proto.has_shape());

  std::unique_ptr<HloInstruction> instruction;
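  // Convenience lambdas that resolve operand and called-computation ids from
  // the proto against the maps of already-deserialized instructions and
  // computations.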
  const auto operands = [&instruction_map, &proto](int index) {
    return instruction_map.at(proto.operand_ids(index));
  };
  const auto all_operands = [&instruction_map, &proto]() {
    std::vector<HloInstruction*> result(proto.operand_ids_size());
    std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
                   result.begin(), [&instruction_map](int64 operand_id) {
                     return instruction_map.at(operand_id);
                   });
    return result;
  };
  const auto computations = [&computation_map, &proto](int index) {
    return computation_map.at(proto.called_computation_ids(index));
  };
  const auto all_computations = [&computation_map, &proto]() {
    std::vector<HloComputation*> result(proto.called_computation_ids_size());
    std::transform(proto.called_computation_ids().begin(),
                   proto.called_computation_ids().end(), result.begin(),
                   [&computation_map](int64 computation_id) {
                     return computation_map.at(computation_id);
                   });
    return result;
  };

  TF_RET_CHECK(
      absl::c_all_of(proto.operand_ids(),
                     [&](int64 id) { return instruction_map.contains(id); }))
      << proto.name() << " instruction contains invalid operand id(s)";

  TF_RET_CHECK(
      absl::c_all_of(proto.called_computation_ids(),
                     [&](int64 id) { return computation_map.contains(id); }))
      << proto.name() << " instruction references invalid computation id(s)";

  Shape shape(proto.shape());
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));

  absl::optional<int> arity = HloOpcodeArity(opcode);
  if (arity) {
    TF_RET_CHECK(proto.operand_ids_size() == *arity)
        << proto.opcode() << " instruction should have " << *arity
        << " operands but sees " << proto.operand_ids_size();
  }

  switch (opcode) {
    // Ops migrated to subclasses.
    case HloOpcode::kBatchNormTraining:
      instruction =
          CreateBatchNormTraining(shape, operands(0), operands(1), operands(2),
                                  proto.epsilon(), proto.feature_index());
      break;
    case HloOpcode::kBatchNormInference:
      instruction = CreateBatchNormInference(
          shape, operands(0), operands(1), operands(2), operands(3),
          operands(4), proto.epsilon(), proto.feature_index());
      break;
    case HloOpcode::kBatchNormGrad:
      instruction = CreateBatchNormGrad(shape, operands(0), operands(1),
                                        operands(2), operands(3), operands(4),
                                        proto.epsilon(), proto.feature_index());
      break;
    case HloOpcode::kFft: {
      std::vector<int64> fft_length(proto.fft_length().begin(),
                                    proto.fft_length().end());
      instruction = CreateFft(shape, operands(0), proto.fft_type(),
                              absl::Span<const int64>(fft_length));
      break;
    }
    case HloOpcode::kCopyStart: {
      instruction = CreateCopyStart(shape, operands(0),
                                    proto.is_cross_program_prefetch());
      break;
    }
    case HloOpcode::kCompare: {
      // If this was auto-upgraded from a deprecated opcode, the comparison
      // direction is already set and is not read from the proto.
      if (!comparison_direction) {
        TF_ASSIGN_OR_RETURN(
            comparison_direction,
            StringToComparisonDirection(proto.comparison_direction()));
      }
      auto comparison_type_str = proto.comparison_type();
      if (!comparison_type_str.empty()) {
        // If a comparison type is specified, it *must* be valid.
        TF_ASSIGN_OR_RETURN(auto comparison_type,
                            StringToComparisonType(comparison_type_str));
        instruction = CreateCompare(shape, operands(0), operands(1),
                                    *comparison_direction, comparison_type);
      } else {
        // The comparison type is optional; when it is not specified, it is
        // determined from the types of the operands.
        instruction = CreateCompare(shape, operands(0), operands(1),
                                    *comparison_direction);
      }
      break;
    }
    case HloOpcode::kTriangularSolve: {
      instruction = CreateTriangularSolve(shape, operands(0), operands(1),
                                          proto.triangular_solve_options());
      break;
    }
    case HloOpcode::kCholesky: {
      instruction =
          CreateCholesky(shape, operands(0), proto.cholesky_options());
      break;
    }
    case HloOpcode::kSend:
      instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
                               proto.is_host_transfer());
      break;
    case HloOpcode::kSendDone:
      instruction = CreateSendDone(operands(0), proto.is_host_transfer());
      break;
    case HloOpcode::kRecv:
      instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
                               proto.channel_id(), proto.is_host_transfer());
      break;
    case HloOpcode::kRecvDone:
      instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
      break;
    case HloOpcode::kReverse:
      instruction = CreateReverse(shape, operands(0),
                                  std::vector<int64>(proto.dimensions().begin(),
                                                     proto.dimensions().end()));
      break;
    case HloOpcode::kConcatenate:
      TF_RET_CHECK(proto.dimensions_size() == 1)
          << "Concatenate instruction should have 1 dimension but sees "
          << proto.dimensions_size();
      instruction =
          CreateConcatenate(shape, all_operands(), proto.dimensions(0));
      break;
    case HloOpcode::kConditional: {
      TF_RET_CHECK(proto.called_computation_ids_size() > 0)
          << "conditional should have at least 1 called computation";
      if (operands(0)->shape().element_type() == PRED) {
        TF_RET_CHECK(proto.called_computation_ids_size() == 2)
            << "conditional should have exactly 2 called computations but got "
            << proto.called_computation_ids_size();
      }
      TF_RET_CHECK(proto.operand_ids_size() ==
                   proto.called_computation_ids_size() + 1)
          << "conditional should have one branch_index operand plus one "
             "operand per called computation but got "
          << proto.operand_ids_size() << " operands for "
          << proto.called_computation_ids_size() << " branch computations";
      auto cond_operands = all_operands();
      instruction =
          CreateConditional(shape, cond_operands[0], all_computations(),
                            absl::MakeSpan(cond_operands).subspan(1));
      break;
    }
    case HloOpcode::kReduce:
      TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
          << "Reduce instruction should have an even number of operands but "
             "sees "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Reduce instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      {
        const auto reduce_operands = all_operands();
        auto inputs = absl::MakeSpan(reduce_operands)
                          .subspan(0, reduce_operands.size() / 2);
        auto init_values =
            absl::MakeSpan(reduce_operands)
                .subspan(reduce_operands.size() / 2, reduce_operands.size());
        instruction =
            CreateReduce(shape, inputs, init_values,
                         std::vector<int64>(proto.dimensions().begin(),
                                            proto.dimensions().end()),
                         computations(0));
      }
      break;
    case HloOpcode::kSort: {
      TF_RET_CHECK(proto.operand_ids_size() >= 1)
          << "Sort instruction should have at least 1 operand but has "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.dimensions().size() == 1)
          << "Sort instruction should have 1 dimension";
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Sort instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      auto sort_operands = all_operands();
      instruction = CreateSort(shape, proto.dimensions(0), all_operands(),
                               computations(0), proto.is_stable());
      break;
    }
    case HloOpcode::kTranspose:
      instruction =
          CreateTranspose(shape, operands(0),
                          std::vector<int64>(proto.dimensions().begin(),
                                             proto.dimensions().end()));
      break;
    case HloOpcode::kBroadcast:
      instruction =
          CreateBroadcast(shape, operands(0),
                          std::vector<int64>(proto.dimensions().begin(),
                                             proto.dimensions().end()));
      break;
    case HloOpcode::kMap:
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Map instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      instruction = CreateMap(shape, all_operands(), computations(0));
      break;
    case HloOpcode::kSlice: {
      std::vector<int64> slice_starts, slice_limits, slice_strides;
      for (const HloInstructionProto::SliceDimensions& slice_dimensions :
           proto.slice_dimensions()) {
        slice_starts.push_back(slice_dimensions.start());
        slice_limits.push_back(slice_dimensions.limit());
        slice_strides.push_back(slice_dimensions.stride());
      }
      instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits,
                                slice_strides);
      break;
    }
    case HloOpcode::kConstant: {
      // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
      if (proto.has_literal()) {
        TF_ASSIGN_OR_RETURN(
            auto literal,
            Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
        instruction = CreateConstant(std::move(literal));
        // Literal's shape may have no/different tiling info.
        TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
            instruction->shape(), shape))
            << instruction->shape().ToString(true) << " vs "
            << shape.ToString(true);
        *instruction->mutable_shape() = shape;
      } else {
        instruction = absl::make_unique<HloConstantInstruction>(shape);
      }
      break;
    }
    case HloOpcode::kTrace: {
      TF_RET_CHECK(proto.has_literal());
      TF_ASSIGN_OR_RETURN(
          auto literal,
          Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
      instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
      break;
    }
    case HloOpcode::kFusion: {
      // In the proto, fused computations are held exclusively within the
      // HloInstructionProto and do not appear as an HloComputationProto within
      // the HloModuleProto.
      TF_RET_CHECK(!proto.fusion_kind().empty());
      TF_ASSIGN_OR_RETURN(FusionKind fusion_kind,
                          StringToFusionKind(proto.fusion_kind()));

      // Find the fused computation and set its fusion instruction.
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Expect 1 called computation for fusion instruction but sees "
          << proto.called_computation_ids_size();
      const int64 fusion_id = proto.called_computation_ids(0);
      auto* fused_computation =
          tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
      TF_RET_CHECK(fused_computation != nullptr)
          << "No fusion computation with id " << fusion_id;
      instruction =
          CreateFusion(shape, fusion_kind, all_operands(), fused_computation);
      break;
    }
    case HloOpcode::kRng:
      instruction = CreateRng(shape, proto.distribution(), all_operands());
      break;
    case HloOpcode::kRngBitGenerator:
      instruction =
          CreateRngBitGenerator(shape, operands(0), proto.rng_algorithm());
      break;
    case HloOpcode::kRngGetAndUpdateState:
      instruction = CreateRngGetAndUpdateState(shape, proto.delta());
      break;
    case HloOpcode::kParameter:
      instruction =
          CreateParameter(proto.parameter_number(), shape, proto.name());
      if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) {
        instruction->set_parameter_replicated_at_leaf_buffers(
            proto.parameter_replication().replicated_at_leaf_buffers());
      }
      break;
    case HloOpcode::kGetTupleElement:
      instruction =
          CreateGetTupleElement(shape, operands(0), proto.tuple_index());
      break;
    case HloOpcode::kReducePrecision:
      instruction = CreateReducePrecision(
          shape, operands(0), proto.exponent_bits(), proto.mantissa_bits());
      break;
    case HloOpcode::kInfeed: {
      TF_RET_CHECK(shape.IsTuple() &&
                   (ShapeUtil::TupleElementCount(shape) == 2))
          << "Infeed should have a tuple shape with 2 operands, but has: "
          << shape;
      const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0);
      instruction =
          CreateInfeed(data_shape, operands(0), proto.infeed_config());
    } break;
    case HloOpcode::kOutfeed: {
      Shape outfeed_shape(proto.outfeed_shape());
      TF_RETURN_IF_ERROR(
          ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape));
      instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1),
                                  proto.outfeed_config());
      break;
    }
    case HloOpcode::kAllGather: {
      absl::optional<int64> channel_id;
      if (proto.channel_id() > 0) {
        channel_id = proto.channel_id();
      }

      TF_RET_CHECK(proto.dimensions_size() == 1)
          << "AllGather cannot have more than 1 all-gather dimension";
      TF_RET_CHECK(all_operands().size() == 1)
          << "AllGather must have a single operand";
      int64 all_gather_dimension = proto.dimensions(0);
      instruction = CreateAllGather(
          shape, operands(0), all_gather_dimension,
          std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                    proto.replica_groups().end()),
          proto.constrain_layout(), channel_id, proto.use_global_device_ids());
      break;
    }
    case HloOpcode::kAllReduce: {
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "AllReduce should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0)
          << "AllReduce cannot have both channel_id() and all_reduce_id()";
      absl::optional<int64> channel_id;
      if (proto.channel_id() > 0) {
        channel_id = proto.channel_id();
      }
      if (proto.all_reduce_id() > 0) {
        channel_id = proto.all_reduce_id();
      }
      instruction = CreateAllReduce(
          shape, all_operands(), computations(0),
          /*replica_groups=*/
          std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                    proto.replica_groups().end()),
          /*constrain_layout=*/proto.constrain_layout(),
          /*channel_id=*/channel_id,
          /*use_global_device_ids=*/proto.use_global_device_ids());
      break;
    }
    case HloOpcode::kAllToAll: {
      absl::optional<int64> channel_id;
      if (proto.channel_id() > 0) {
        channel_id = proto.channel_id();
      }
      absl::optional<int64> split_dimension;
      if (proto.dimensions_size() > 0) {
        TF_RET_CHECK(proto.dimensions_size() == 1)
            << "AllToAll cannot have more than 1 dimension (split dimension)";
        TF_RET_CHECK(all_operands().size() == 1)
            << "AllToAll must have a single operand when the split dimension "
               "is specified";
        split_dimension = proto.dimensions(0);
      }
      instruction = CreateAllToAll(
          shape, all_operands(),
          /*replica_groups=*/
          std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                    proto.replica_groups().end()),
          /*constrain_layout=*/proto.constrain_layout(),
          /*channel_id=*/channel_id, split_dimension);
      break;
    }
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kCollectivePermuteStart: {
      std::vector<std::pair<int64, int64>> source_target_pairs(
          proto.source_target_pairs_size());
      absl::optional<int64> channel_id;
      if (proto.channel_id() > 0) {
        channel_id = proto.channel_id();
      }
      for (int i = 0; i < source_target_pairs.size(); i++) {
        source_target_pairs[i].first = proto.source_target_pairs(i).source();
        source_target_pairs[i].second = proto.source_target_pairs(i).target();
      }

      if (opcode == HloOpcode::kCollectivePermute) {
        instruction = CreateCollectivePermute(shape, operands(0),
                                              source_target_pairs, channel_id);
      } else if (opcode == HloOpcode::kCollectivePermuteStart) {
        instruction = CreateCollectivePermuteStart(
            shape, operands(0), source_target_pairs, channel_id);
      } else {
        LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
                   << "but got " << HloOpcodeString(opcode);
      }
      break;
    }
    case HloOpcode::kReplicaId: {
      instruction = CreateReplicaId(shape);
      break;
    }
    case HloOpcode::kPartitionId: {
      instruction = CreatePartitionId(shape);
      break;
    }
    case HloOpcode::kConvolution: {
      TF_RET_CHECK(proto.has_window());
      TF_RET_CHECK(proto.has_convolution_dimension_numbers());
      PrecisionConfig precision_config = proto.precision_config();
      precision_config.mutable_operand_precision()->Resize(
          proto.operand_ids_size(), PrecisionConfig::DEFAULT);
      instruction = CreateConvolve(
          shape, operands(0), operands(1),
          std::max<int64>(proto.feature_group_count(), 1),
          std::max<int64>(proto.batch_group_count(), 1), proto.window(),
          proto.convolution_dimension_numbers(), precision_config);
      break;
    }
    case HloOpcode::kReduceWindow:
      TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
          << "Reduce window should have an even number of operands but "
             "sees "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "ReduceWindow should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      {
        const auto reduce_operands = all_operands();
        auto inputs = absl::MakeSpan(reduce_operands)
                          .subspan(0, reduce_operands.size() / 2);
        auto init_values =
            absl::MakeSpan(reduce_operands)
                .subspan(reduce_operands.size() / 2, reduce_operands.size());
        instruction = CreateReduceWindow(shape, inputs, init_values,
                                         proto.window(), computations(0));
      }
      break;
    case HloOpcode::kSelectAndScatter:
      TF_RET_CHECK(proto.called_computation_ids_size() == 2)
          << "SelectAndScatter should have 2 called computations but sees "
          << proto.called_computation_ids_size();
      instruction = CreateSelectAndScatter(shape, operands(0), computations(0),
                                           proto.window(), operands(1),
                                           operands(2), computations(1));
      break;
    case HloOpcode::kCustomCall: {
      if (proto.constrain_layout()) {
        // A proto RepeatedPtrField cannot be converted to a Span (it is a
        // vector of pointers essentially) so create a vector of shapes to pass
        // in.
        std::vector<Shape> operand_shapes;
        for (const ShapeProto& shape_proto :
             proto.operand_shapes_with_layout()) {
          operand_shapes.emplace_back(shape_proto);
        }
        instruction =
            CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
                             operand_shapes, proto.backend_config());
      } else {
        if (proto.called_computation_ids_size() == 1) {
          instruction = CreateCustomCall(shape, all_operands(), computations(0),
                                         proto.custom_call_target(),
                                         proto.backend_config());
        } else if (proto.called_computation_ids_size() > 1) {
          instruction = CreateCustomCall(
              shape, all_operands(), all_computations(),
              proto.custom_call_target(), proto.backend_config());

        } else {
          instruction = CreateCustomCall(shape, all_operands(),
                                         proto.custom_call_target(),
                                         proto.backend_config());
        }
      }
      auto custom_call_instr =
          Cast<HloCustomCallInstruction>(instruction.get());
      if (proto.has_window()) {
        custom_call_instr->set_window(proto.window());
      }
      if (proto.has_literal()) {
        TF_ASSIGN_OR_RETURN(
            auto literal,
            Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
        custom_call_instr->set_literal(std::move(literal));
      }
      if (proto.has_convolution_dimension_numbers()) {
        custom_call_instr->set_convolution_dimension_numbers(
            proto.convolution_dimension_numbers());
      }
      custom_call_instr->set_feature_group_count(
          std::max(static_cast<int64>(proto.feature_group_count()), int64{1}));
      custom_call_instr->set_batch_group_count(
          std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
      custom_call_instr->set_custom_call_has_side_effect(
          proto.custom_call_has_side_effect());
      custom_call_instr->set_padding_type(proto.padding_type());

      PrecisionConfig precision_config = proto.precision_config();
      precision_config.mutable_operand_precision()->Resize(
          proto.operand_ids_size(), PrecisionConfig::DEFAULT);
      *custom_call_instr->mutable_precision_config() = precision_config;
      std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_to_operand_aliasing;
      for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
        output_to_operand_aliasing.emplace_back(
            ShapeIndex(aliasing.output_shape_index().begin(),
                       aliasing.output_shape_index().end()),
            std::pair<int64, ShapeIndex>{
                aliasing.operand_index(),
                ShapeIndex(aliasing.operand_shape_index().begin(),
                           aliasing.operand_shape_index().end())});
      }
      custom_call_instr->set_output_to_operand_aliasing(
          std::move(output_to_operand_aliasing));
      break;
    }
    case HloOpcode::kPad:
      TF_RET_CHECK(proto.has_padding_config());
      instruction =
          CreatePad(shape, operands(0), operands(1), proto.padding_config());
      break;
    case HloOpcode::kDynamicSlice: {
      std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
      absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
      TF_RET_CHECK(proto.operand_ids_size() >= 1)
          << "DynamicSlice instruction should have at least 1 operand but "
             "sees "
          << proto.operand_ids_size();
      // TODO(b/118437727): Old form, make the check unconditional.
      if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) {
        auto expected_operands = 1 + operands(0)->shape().rank();
        TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
            << "DynamicSlice instruction should have " << expected_operands
            << " operands, but has " << proto.operand_ids_size();
      }
      const auto& operand_vector = all_operands();
      instruction = CreateDynamicSlice(
          shape, operands(0), absl::MakeSpan(operand_vector).subspan(1),
          slice_sizes);
      break;
    }
    case HloOpcode::kDynamicUpdateSlice: {
      TF_RET_CHECK(proto.operand_ids_size() >= 2)
          << "DynamicUpdateSlice instruction should have at least 2 operands "
             "but sees "
          << proto.operand_ids_size();
      // TODO(b/118437727): Old form, make the check unconditional.
      if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) {
        auto expected_operands = 2 + operands(0)->shape().rank();
        TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
            << "DynamicUpdateSlice instruction should have "
            << expected_operands << " operands, but has "
            << proto.operand_ids_size();
      }
      const auto& operand_vector = all_operands();
      instruction =
          CreateDynamicUpdateSlice(shape, operands(0), operands(1),
                                   absl::MakeSpan(operand_vector).subspan(2));

      break;
    }
    case HloOpcode::kGather: {
      TF_RET_CHECK(proto.has_gather_dimension_numbers())
          << "Gather instruction should have GatherDimensionNumbers set.";
      auto gather_dimension_numbers = absl::make_unique<GatherDimensionNumbers>(
          proto.gather_dimension_numbers());
      std::vector<int64> gather_slice_sizes;
      for (int64 bound : proto.gather_slice_sizes()) {
        gather_slice_sizes.push_back(bound);
      }
      instruction = CreateGather(shape, operands(0), operands(1),
                                 *gather_dimension_numbers, gather_slice_sizes,
                                 proto.indices_are_sorted());
      break;
    }
    case HloOpcode::kScatter: {
      TF_RET_CHECK(proto.has_scatter_dimension_numbers())
          << "Scatter instruction should have ScatterDimensionNumbers set.";
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Scatter instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      auto scatter_dimension_numbers =
          absl::make_unique<ScatterDimensionNumbers>(
              proto.scatter_dimension_numbers());
      instruction =
          CreateScatter(shape, operands(0), operands(1), operands(2),
                        computations(0), *scatter_dimension_numbers,
                        proto.indices_are_sorted(), proto.unique_indices());
      break;
    }
    case HloOpcode::kIota:
      TF_RET_CHECK(proto.dimensions_size() == 1)
          << "Iota instruction should have 1 dimension but sees "
          << proto.dimensions_size();
      instruction = CreateIota(shape, proto.dimensions(0));
      break;
    case HloOpcode::kDot: {
      TF_RET_CHECK(proto.has_dot_dimension_numbers())
          << "Dot instruction should have dot_dimension_numbers.";
      PrecisionConfig precision_config = proto.precision_config();
      precision_config.mutable_operand_precision()->Resize(
          proto.operand_ids_size(), PrecisionConfig::DEFAULT);
      instruction = absl::make_unique<HloDotInstruction>(
          shape, operands(0), operands(1), proto.dot_dimension_numbers(),
          precision_config);
      break;
    }
    case HloOpcode::kDomain: {
      std::shared_ptr<const HloSharding> entry_hlo_sharding;
      std::shared_ptr<const HloSharding> exit_hlo_sharding;
      if (proto.has_domain_entry_sharding()) {
        TF_ASSIGN_OR_RETURN(
            HloSharding sharding,
            HloSharding::FromProto(proto.domain_entry_sharding()));
        entry_hlo_sharding = std::make_shared<const HloSharding>(sharding);
      }
      if (proto.has_domain_exit_sharding()) {
        TF_ASSIGN_OR_RETURN(
            HloSharding sharding,
            HloSharding::FromProto(proto.domain_exit_sharding()));
        exit_hlo_sharding = std::make_shared<const HloSharding>(sharding);
      }
      instruction = absl::make_unique<HloDomainInstruction>(
          shape, operands(0),
          absl::make_unique<ShardingMetadata>(entry_hlo_sharding),
          absl::make_unique<ShardingMetadata>(exit_hlo_sharding));
      break;
    }
    case HloOpcode::kGetDimensionSize:
      TF_RET_CHECK(proto.dimensions_size() == 1);
      instruction =
          CreateGetDimensionSize(shape, operands(0), proto.dimensions(0));
      break;
    case HloOpcode::kSetDimensionSize:
      TF_RET_CHECK(proto.dimensions_size() == 1);
      instruction = CreateSetDimensionSize(shape, operands(0), operands(1),
                                           proto.dimensions(0));
      break;
    case HloOpcode::kReshape: {
      int64 inferred_dimension = -1;
      if (!proto.dimensions().empty()) {
        inferred_dimension = proto.dimensions()[0];
      }
      TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
                   ShapeUtil::ElementsIn(shape) ==
                       ShapeUtil::ElementsIn(operands(0)->shape()))
          << "shape: " << ShapeUtil::HumanString(shape)
          << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
      instruction = CreateReshape(shape, operands(0), inferred_dimension);
      break;
    }
    case HloOpcode::kDynamicReshape: {
      TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
                   ShapeUtil::ElementsIn(shape) ==
                       ShapeUtil::ElementsIn(operands(0)->shape()))
          << "shape: " << ShapeUtil::HumanString(shape)
          << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
      const auto& operand_vector = all_operands();
      instruction = CreateDynamicReshape(
          shape, operands(0), absl::MakeSpan(operand_vector).subspan(1));
      break;
    }
    default: {
      instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
      for (const int64 operand_id : proto.operand_ids()) {
        instruction->AppendOperand(instruction_map.at(operand_id));
      }
      if (instruction->opcode() != HloOpcode::kFusion) {
        if (instruction->opcode() == HloOpcode::kCall) {
          TF_RET_CHECK(proto.called_computation_ids_size() == 1)
              << "Call should have 1 called computation but has "
              << proto.called_computation_ids_size();
        }
        for (const int64 computation_id : proto.called_computation_ids()) {
          instruction->called_computations_.push_back(
              computation_map.at(computation_id));
        }
      }
      TF_RET_CHECK(!proto.has_precision_config())
          << instruction->opcode() << proto.DebugString();
      TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
      break;
    }
  }

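  // Wire up any control dependencies recorded in the proto.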
  for (const int64 predecessor_id : proto.control_predecessor_ids()) {
    TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
        << "No instruction with id " << predecessor_id;
    TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
                           ->AddControlDependencyTo(instruction.get()));
  }

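  // Fields common to all opcodes: name, metadata, backend config, and
  // partitioning information.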
  TF_RET_CHECK(!proto.name().empty());
  instruction->SetAndSanitizeName(proto.name());
  instruction->metadata_ = proto.metadata();
  instruction->backend_config_ = proto.backend_config();
  instruction->outer_dimension_partitions_.assign(
      proto.outer_dimension_partitions().begin(),
      proto.outer_dimension_partitions().end());

  TF_RET_CHECK(proto.id() >= 0)
      << "Instruction with negative id: " << proto.id();
  TF_RET_CHECK(proto.id() <= INT_MAX)
      << "Instruction with id > INT_MAX: " << proto.id();
  instruction->unique_id_ = proto.id();

  if (proto.has_sharding()) {
    TF_ASSIGN_OR_RETURN(const auto& sharding,
                        HloSharding::FromProto(proto.sharding()));
    instruction->set_sharding(sharding);
  }

  if (proto.has_frontend_attributes()) {
    instruction->set_frontend_attributes(proto.frontend_attributes());
  }

  return std::move(instruction);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
    int64 parameter_number, const Shape& shape, const string& name) {
  return absl::make_unique<HloParameterInstruction>(parameter_number, shape,
                                                    name);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
    const string& tag, HloInstruction* operand) {
  return absl::make_unique<HloTraceInstruction>(tag, operand);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
    Literal literal) {
  return absl::make_unique<HloConstantInstruction>(std::move(literal));
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
    const Shape& shape, int64 iota_dimension) {
  return absl::make_unique<HloIotaInstruction>(shape, iota_dimension);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(const Shape& shape,
                                      HloInstruction* operand, int64 index) {
  return absl::make_unique<HloGetTupleElementInstruction>(shape, operand,
                                                          index);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
    const Shape& shape, RandomDistribution distribution,
    absl::Span<HloInstruction* const> parameters) {
  return absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateRngGetAndUpdateState(const Shape& shape, int64 delta) {
  return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
                                      RandomAlgorithm algorithm) {
  return absl::make_unique<HloRngBitGeneratorInstruction>(shape, state,
                                                          algorithm);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
    const Shape& shape, HloOpcode opcode,
    absl::Span<HloInstruction* const> operands) {
  if (opcode == HloOpcode::kCopy) {
    // It is impossible to copy an opaque shape; we don't know how big it is.
874     CHECK(!shape.IsOpaque());
875   }
876   auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
877   for (auto operand : operands) {
878     instruction->AppendOperand(operand);
879   }
880   return instruction;
881 }
882 
CreateUnary(const Shape & shape,HloOpcode opcode,HloInstruction * operand)883 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
884     const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
885   // Only certain opcodes are supported with CreateUnary: opcodes of unary
886   // instructions with no auxiliary fields.
887   switch (opcode) {
888     case HloOpcode::kAbs:
889     case HloOpcode::kRoundNearestAfz:
890     case HloOpcode::kBitcast:
891     case HloOpcode::kCeil:
892     case HloOpcode::kCollectivePermuteDone:
893     case HloOpcode::kCopy:
894     case HloOpcode::kCopyDone:
895     case HloOpcode::kCos:
896     case HloOpcode::kClz:
897     case HloOpcode::kExp:
898     case HloOpcode::kExpm1:
899     case HloOpcode::kFloor:
900     case HloOpcode::kImag:
901     case HloOpcode::kIsFinite:
902     case HloOpcode::kLog:
903     case HloOpcode::kLog1p:
904     case HloOpcode::kNot:
905     case HloOpcode::kNegate:
906     case HloOpcode::kPopulationCount:
907     case HloOpcode::kReal:
908     case HloOpcode::kRsqrt:
909     case HloOpcode::kLogistic:
910     case HloOpcode::kSign:
911     case HloOpcode::kSin:
912     case HloOpcode::kSqrt:
913     case HloOpcode::kCbrt:
914     case HloOpcode::kTanh:
915       break;
916     default:
917       LOG(FATAL) << "Invalid unary instruction opcode "
918                  << HloOpcodeString(opcode);
919   }
920   return CreateNary(shape, opcode, {operand});
921 }
922 
CreateBinary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs)923 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
924     const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
925     HloInstruction* rhs) {
926   // Only certain opcodes are supported with CreateBinary: opcodes of binary
927   // instructions with no auxiliary fields.
928   switch (opcode) {
929     case HloOpcode::kAdd:
930     case HloOpcode::kAtan2:
931     case HloOpcode::kDivide:
932     case HloOpcode::kComplex:
933     case HloOpcode::kMaximum:
934     case HloOpcode::kMinimum:
935     case HloOpcode::kMultiply:
936     case HloOpcode::kPower:
937     case HloOpcode::kRemainder:
938     case HloOpcode::kSubtract:
939     case HloOpcode::kAnd:
940     case HloOpcode::kOr:
941     case HloOpcode::kXor:
942     case HloOpcode::kShiftLeft:
943     case HloOpcode::kShiftRightArithmetic:
944     case HloOpcode::kShiftRightLogical:
945       break;
946     default:
947       LOG(FATAL) << "Invalid binary instruction opcode "
948                  << HloOpcodeString(opcode);
949   }
950   return CreateNary(shape, opcode, {lhs, rhs});
951 }
952 
CreateTernary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs,HloInstruction * ehs)953 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
954     const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
955     HloInstruction* rhs, HloInstruction* ehs) {
956   // Only certain opcodes are supported with CreateTernary: opcodes of ternary
957   // instructions with no auxiliary fields.
958   switch (opcode) {
959     case HloOpcode::kClamp:
960     case HloOpcode::kSelect:
961     case HloOpcode::kTupleSelect:
962       break;
963     default:
964       LOG(FATAL) << "Invalid ternary instruction opcode "
965                  << HloOpcodeString(opcode);
966   }
967   return CreateNary(shape, opcode, {lhs, rhs, ehs});
968 }
969 
CreateVariadic(const Shape & shape,HloOpcode opcode,absl::Span<HloInstruction * const> operands)970 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
971     const Shape& shape, HloOpcode opcode,
972     absl::Span<HloInstruction* const> operands) {
973   CHECK_EQ(HloOpcode::kTuple, opcode);
974   return CreateNary(shape, opcode, operands);
975 }
976 
CreateMap(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)977 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
978     const Shape& shape, absl::Span<HloInstruction* const> operands,
979     HloComputation* map_computation) {
980   return absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
981 }
982 
CreateConvolve(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64 feature_group_count,int64 batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)983 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
984     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
985     int64 feature_group_count, int64 batch_group_count, const Window& window,
986     const ConvolutionDimensionNumbers& dimension_numbers,
987     const PrecisionConfig& precision_config) {
988   return absl::make_unique<HloConvolutionInstruction>(
989       shape, lhs, rhs, feature_group_count, batch_group_count, window,
990       dimension_numbers, precision_config);
991 }
992 
CreateFft(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64> fft_length)993 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
994     const Shape& shape, HloInstruction* operand, FftType fft_type,
995     absl::Span<const int64> fft_length) {
996   return absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
997                                               fft_length);
998 }
999 
CreateCopyStart(const Shape & shape,HloInstruction * operand,bool is_cross_program_prefetch)1000 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCopyStart(
1001     const Shape& shape, HloInstruction* operand,
1002     bool is_cross_program_prefetch) {
1003   return absl::make_unique<HloCopyStartInstruction>(shape, operand,
1004                                                     is_cross_program_prefetch);
1005 }
1006 
CreateCompare(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction,absl::optional<Comparison::Type> type)1007 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
1008     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1009     ComparisonDirection direction, absl::optional<Comparison::Type> type) {
1010   return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction,
1011                                                   type);
1012 }
1013 
1014 /* static */ std::unique_ptr<HloInstruction>
CreateTriangularSolve(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)1015 HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
1016                                       HloInstruction* b,
1017                                       const TriangularSolveOptions& options) {
1018   return absl::make_unique<HloTriangularSolveInstruction>(shape, a, b, options);
1019 }
1020 
CreateCholesky(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)1021 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCholesky(
1022     const Shape& shape, HloInstruction* a, const CholeskyOptions& options) {
1023   return absl::make_unique<HloCholeskyInstruction>(shape, a, options);
1024 }
1025 
CreateDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)1026 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
1027     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1028     const DotDimensionNumbers& dimension_numbers,
1029     const PrecisionConfig& precision_config) {
1030   return absl::make_unique<HloDotInstruction>(
1031       shape, lhs, rhs, dimension_numbers, precision_config);
1032 }
1033 
1034 /* static */ std::unique_ptr<HloInstruction>
CreateReducePrecision(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)1035 HloInstruction::CreateReducePrecision(const Shape& shape,
1036                                       HloInstruction* operand,
1037                                       const int exponent_bits,
1038                                       const int mantissa_bits) {
1039   return absl::make_unique<HloReducePrecisionInstruction>(
1040       shape, operand, exponent_bits, mantissa_bits);
1041 }
1042 
CreateAllGather(const Shape & shape,HloInstruction * operand,int64 all_gather_dimension,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)1043 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllGather(
1044     const Shape& shape, HloInstruction* operand, int64 all_gather_dimension,
1045     const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
1046     const absl::optional<int64>& channel_id, bool use_global_device_ids) {
1047   return absl::make_unique<HloAllGatherInstruction>(
1048       shape, operand, all_gather_dimension, replica_groups, constrain_layout,
1049       channel_id, use_global_device_ids);
1050 }
1051 
CreateAllReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,bool use_global_device_ids)1052 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllReduce(
1053     const Shape& shape, absl::Span<HloInstruction* const> operands,
1054     HloComputation* reduce_computation,
1055     const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
1056     const absl::optional<int64>& channel_id, bool use_global_device_ids) {
1057   return absl::make_unique<HloAllReduceInstruction>(
1058       shape, operands, reduce_computation, replica_groups, constrain_layout,
1059       channel_id, use_global_device_ids);
1060 }
1061 
CreateAllToAll(const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id,const absl::optional<int64> & split_dimension)1062 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
1063     const Shape& shape, absl::Span<HloInstruction* const> operands,
1064     const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
1065     const absl::optional<int64>& channel_id,
1066     const absl::optional<int64>& split_dimension) {
1067   return absl::make_unique<HloAllToAllInstruction>(
1068       shape, operands, replica_groups, constrain_layout, channel_id,
1069       split_dimension);
1070 }
1071 
1072 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermute(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64,int64>> & source_target_pairs,const absl::optional<int64> & channel_id)1073 HloInstruction::CreateCollectivePermute(
1074     const Shape& shape, HloInstruction* operand,
1075     const std::vector<std::pair<int64, int64>>& source_target_pairs,
1076     const absl::optional<int64>& channel_id) {
1077   return absl::make_unique<HloCollectivePermuteInstruction>(
1078       HloOpcode::kCollectivePermute, shape, operand, source_target_pairs,
1079       channel_id);
1080 }
1081 
1082 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermuteStart(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64,int64>> & source_target_pairs,const absl::optional<int64> & channel_id)1083 HloInstruction::CreateCollectivePermuteStart(
1084     const Shape& shape, HloInstruction* operand,
1085     const std::vector<std::pair<int64, int64>>& source_target_pairs,
1086     const absl::optional<int64>& channel_id) {
1087   return absl::make_unique<HloCollectivePermuteInstruction>(
1088       HloOpcode::kCollectivePermuteStart, shape, operand, source_target_pairs,
1089       channel_id);
1090 }
1091 
CreateReplicaId(const Shape & shape)1092 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId(
1093     const Shape& shape) {
1094   CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
1095       << "HloInstruction replica-id must have a shape of u32[], but "
1096       << shape.ToString() << " is specified";
1097   return absl::WrapUnique(new HloInstruction(HloOpcode::kReplicaId, shape));
1098 }
1099 
CreatePartitionId(const Shape & shape)1100 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePartitionId(
1101     const Shape& shape) {
1102   CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
1103       << "HloInstruction partition-id must have a shape of u32[], but "
1104       << shape.ToString() << " is specified";
1105   return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, shape));
1106 }
1107 
CreateInfeed(const Shape & infeed_shape,HloInstruction * token_operand,const string & config)1108 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
1109     const Shape& infeed_shape, HloInstruction* token_operand,
1110     const string& config) {
1111   return absl::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
1112                                                  config);
1113 }
1114 
CreateOutfeed(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)1115 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
1116     const Shape& outfeed_shape, HloInstruction* operand,
1117     HloInstruction* token_operand, absl::string_view outfeed_config) {
1118   return absl::make_unique<HloOutfeedInstruction>(
1119       outfeed_shape, operand, token_operand, outfeed_config);
1120 }
1121 
CreateSend(HloInstruction * operand,HloInstruction * token,int64 channel_id,bool is_host_transfer)1122 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
1123     HloInstruction* operand, HloInstruction* token, int64 channel_id,
1124     bool is_host_transfer) {
1125   return absl::make_unique<HloSendInstruction>(operand, token, channel_id,
1126                                                is_host_transfer);
1127 }
1128 
1129 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
1130     HloInstruction* operand, bool is_host_transfer) {
1131   auto send_operand = DynCast<HloSendInstruction>(operand);
1132   CHECK(send_operand != nullptr)
1133       << "SendDone must take the context operand from Send";
1134   return absl::make_unique<HloSendDoneInstruction>(send_operand,
1135                                                    is_host_transfer);
1136 }
1137 
1138 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
1139     const Shape& shape, HloInstruction* token, int64 channel_id,
1140     bool is_host_transfer) {
1141   return absl::make_unique<HloRecvInstruction>(shape, token, channel_id,
1142                                                is_host_transfer);
1143 }
1144 
1145 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
1146     HloInstruction* operand, bool is_host_transfer) {
1147   auto recv_operand = DynCast<HloRecvInstruction>(operand);
1148   CHECK(recv_operand != nullptr)
1149       << "RecvDone must take the context operand from Recv";
1150   return absl::make_unique<HloRecvDoneInstruction>(recv_operand,
1151                                                    is_host_transfer);
1152 }
1153 
1154 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
1155     const Shape& shape, HloInstruction* operand,
1156     absl::Span<const int64> dimensions) {
1157   return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
1158 }
1159 
1160 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
1161     absl::Span<HloInstruction* const> operands) {
1162   CHECK(!operands.empty());
1163   auto instruction = absl::WrapUnique(
1164       new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
1165   for (auto operand : operands) {
1166     instruction->AppendOperand(operand);
1167   }
1168   return instruction;
1169 }
1170 
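// CreateToken builds a token-shaped kAfterAll with no operands; this is how a
// fresh token value is materialized.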
1171 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
1172   return absl::WrapUnique(
1173       new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
1174 }
1175 
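// kAddDependency forwards the value of 'data_operand' unchanged while adding
// an ordering dependency on 'token_operand'; the result takes the data
// operand's shape.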
1176 /* static */ std::unique_ptr<HloInstruction>
1177 HloInstruction::CreateAddDependency(HloInstruction* data_operand,
1178                                     HloInstruction* token_operand) {
1179   auto instruction = absl::WrapUnique(
1180       new HloInstruction(HloOpcode::kAddDependency, data_operand->shape()));
1181   instruction->AppendOperand(data_operand);
1182   instruction->AppendOperand(token_operand);
1183   return instruction;
1184 }
1185 
1186 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
1187     const Shape& shape, HloComputation* condition, HloComputation* body,
1188     HloInstruction* init) {
1189   auto instruction =
1190       absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
1191   instruction->AppendOperand(init);
1192   // Body comes before condition computation in the vector.
1193   instruction->called_computations_.push_back(body);
1194   instruction->called_computations_.push_back(condition);
1195   return instruction;
1196 }
1197 
1198 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
1199     const Shape& shape, HloInstruction* pred,
1200     HloInstruction* true_computation_arg, HloComputation* true_computation,
1201     HloInstruction* false_computation_arg, HloComputation* false_computation) {
1202   auto instruction =
1203       absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
1204   instruction->AppendOperand(pred);
1205   instruction->AppendOperand(true_computation_arg);
1206   instruction->AppendOperand(false_computation_arg);
1207   // In called_computations_, the index of true_computation must be 0 and that
1208   // of false computation must be 1, as defined by kTrueComputationIndex and
1209   // kFalseComputationIndex.
1210   instruction->called_computations_.push_back(true_computation);
1211   instruction->called_computations_.push_back(false_computation);
1212   return instruction;
1213 }
1214 
1215 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
1216     const Shape& shape, HloInstruction* branch_index,
1217     absl::Span<HloComputation* const> branch_computations,
1218     absl::Span<HloInstruction* const> branch_computation_args) {
1219   auto instruction =
1220       absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
1221   instruction->AppendOperand(branch_index);
1222   CHECK_EQ(branch_computations.size(), branch_computation_args.size());
1223   for (int i = 0; i < branch_computations.size(); ++i) {
1224     instruction->called_computations_.push_back(branch_computations[i]);
1225     instruction->AppendOperand(branch_computation_args[i]);
1226   }
1227   return instruction;
1228 }
1229 
1230 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
1231     const Shape& shape, HloInstruction* operand,
1232     absl::Span<const int64> start_indices,
1233     absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
1234   return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
1235                                                 limit_indices, strides);
1236 }
1237 
1238 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
1239     const Shape& shape, HloInstruction* operand,
1240     absl::Span<HloInstruction* const> start_indices,
1241     absl::Span<const int64> slice_sizes) {
1242   return absl::make_unique<HloDynamicSliceInstruction>(
1243       shape, operand, start_indices, slice_sizes);
1244 }
1245 
1246 /* static */ std::unique_ptr<HloInstruction>
1247 HloInstruction::CreateDynamicUpdateSlice(
1248     const Shape& shape, HloInstruction* operand, HloInstruction* update,
1249     absl::Span<HloInstruction* const> start_indices) {
1250   return absl::make_unique<HloDynamicUpdateSliceInstruction>(
1251       shape, operand, update, start_indices);
1252 }
1253 
1254 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
1255     const Shape& shape, absl::Span<HloInstruction* const> operands,
1256     int64 dimension) {
1257   return absl::make_unique<HloConcatenateInstruction>(shape, operands,
1258                                                       dimension);
1259 }
1260 
1261 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
1262     const Shape& shape, HloInstruction* operand) {
1263   auto instruction =
1264       absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
1265   instruction->AppendOperand(operand);
1266   return instruction;
1267 }
1268 
1269 /* static */ std::unique_ptr<HloInstruction>
1270 HloInstruction::CreateBitcastConvert(const Shape& shape,
1271                                      HloInstruction* operand) {
1272   auto instruction =
1273       absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
1274   instruction->AppendOperand(operand);
1275   return instruction;
1276 }
1277 
1278 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBitcast(
1279     const Shape& shape, HloInstruction* operand) {
1280   auto instruction =
1281       absl::WrapUnique(new HloInstruction(HloOpcode::kBitcast, shape));
1282   instruction->AppendOperand(operand);
1283   return instruction;
1284 }
1285 
1286 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1287     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1288     absl::Span<const int64> dimensions_to_reduce,
1289     HloComputation* reduce_computation) {
1290   auto instruction = absl::WrapUnique(new HloReduceInstruction(
1291       shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
1292   return std::move(instruction);
1293 }
1294 
1295 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1296     const Shape& shape, absl::Span<HloInstruction* const> operands,
1297     absl::Span<HloInstruction* const> init_values,
1298     absl::Span<const int64> dimensions_to_reduce,
1299     HloComputation* reduce_computation) {
1300   std::vector<HloInstruction*> all_args;
1301   all_args.reserve(operands.size() * 2);
1302   all_args.insert(all_args.end(), operands.begin(), operands.end());
1303   all_args.insert(all_args.end(), init_values.begin(), init_values.end());
1304   return absl::make_unique<HloReduceInstruction>(
1305       shape, all_args, dimensions_to_reduce, reduce_computation);
1306 }
1307 
1308 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
1309     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1310     const Window& window, HloComputation* reduce_computation) {
1311   return absl::make_unique<HloReduceWindowInstruction>(
1312       shape, operand, init_value, window, reduce_computation);
1313 }
1314 
1315 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
1316     const Shape& shape, absl::Span<HloInstruction* const> operands,
1317     absl::Span<HloInstruction* const> init_values, const Window& window,
1318     HloComputation* reduce_computation) {
1319   return absl::make_unique<HloReduceWindowInstruction>(
1320       shape, operands, init_values, window, reduce_computation);
1321 }
1322 /* static */ std::unique_ptr<HloInstruction>
1323 HloInstruction::CreateBatchNormTraining(const Shape& shape,
1324                                         HloInstruction* operand,
1325                                         HloInstruction* scale,
1326                                         HloInstruction* offset, float epsilon,
1327                                         int64 feature_index) {
1328   return absl::make_unique<HloBatchNormTrainingInstruction>(
1329       shape, operand, scale, offset, epsilon, feature_index);
1330 }
1331 
1332 /* static */ std::unique_ptr<HloInstruction>
1333 HloInstruction::CreateBatchNormInference(
1334     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
1335     HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
1336     float epsilon, int64 feature_index) {
1337   return absl::make_unique<HloBatchNormInferenceInstruction>(
1338       shape, operand, scale, offset, mean, variance, epsilon, feature_index);
1339 }
1340 
1341 /* static */ std::unique_ptr<HloInstruction>
1342 HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
1343                                     HloInstruction* scale, HloInstruction* mean,
1344                                     HloInstruction* variance,
1345                                     HloInstruction* grad_output, float epsilon,
1346                                     int64 feature_index) {
1347   return absl::make_unique<HloBatchNormGradInstruction>(
1348       shape, operand, scale, mean, variance, grad_output, epsilon,
1349       feature_index);
1350 }
1351 
1352 /* static */ std::unique_ptr<HloInstruction>
1353 HloInstruction::CreateSelectAndScatter(
1354     const Shape& shape, HloInstruction* operand, HloComputation* select,
1355     const Window& window, HloInstruction* source, HloInstruction* init_value,
1356     HloComputation* scatter) {
1357   return absl::make_unique<HloSelectAndScatterInstruction>(
1358       shape, operand, select, window, source, init_value, scatter);
1359 }
1360 
1361 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
1362     const Shape& shape, HloInstruction* operand,
1363     absl::Span<const int64> broadcast_dimensions) {
1364   return absl::make_unique<HloBroadcastInstruction>(shape, operand,
1365                                                     broadcast_dimensions);
1366 }
1367 
1368 /* static */ std::unique_ptr<HloInstruction>
1369 HloInstruction::CreateGetDimensionSize(const Shape& shape,
1370                                        HloInstruction* operand,
1371                                        int64 dimension) {
1372   return absl::make_unique<HloGetDimensionSizeInstruction>(shape, operand,
1373                                                            dimension);
1374 }
1375 
1376 /* static */ std::unique_ptr<HloInstruction>
1377 HloInstruction::CreateSetDimensionSize(const Shape& shape,
1378                                        HloInstruction* operand,
1379                                        HloInstruction* val, int64 dimension) {
1380   return absl::make_unique<HloSetDimensionSizeInstruction>(shape, operand, val,
1381                                                            dimension);
1382 }
1383 
1384 /* static */ std::unique_ptr<HloInstruction>
1385 HloInstruction::CreateBroadcastSequence(
1386     const Shape& output_shape, HloInstruction* operand,
1387     const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
1388         adder) {
1389   CHECK(ShapeUtil::IsScalar(operand->shape()) ||
1390         operand->shape().rank() == output_shape.rank());
1391   Shape broadcast_shape = ShapeUtil::ChangeElementType(
1392       output_shape, operand->shape().element_type());
1393   // Do explicit broadcast for scalar.
1394   if (ShapeUtil::IsScalar(operand->shape())) {
1395     auto broadcast =
1396         HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
1397     broadcast->set_metadata(operand->metadata());
1398     if (operand->has_sharding()) {
1399       broadcast->set_sharding(operand->sharding());
1400     }
1401     broadcast->set_frontend_attributes(operand->frontend_attributes());
1402     return broadcast;
1403   }
1404   // Do explicit broadcast for degenerate broadcast.
1405   std::vector<int64> broadcast_dimensions;
1406   std::vector<int64> reshaped_dimensions;
1407   for (int i = 0; i < operand->shape().rank(); i++) {
1408     if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
1409       broadcast_dimensions.push_back(i);
1410       reshaped_dimensions.push_back(operand->shape().dimensions(i));
1411     } else {
1412       CHECK_EQ(operand->shape().dimensions(i), 1)
1413           << "An explicit broadcast sequence requires the broadcasted "
1414              "dimensions to be trivial; operand: "
1415           << operand->ToString() << "; output_shape: " << output_shape;
1416     }
1417   }
1418   // Eliminate the size one dimensions.
1419   HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
1420       ShapeUtil::MakeShape(operand->shape().element_type(),
1421                            reshaped_dimensions),
1422       operand));
1423   reshaped_operand->set_metadata(operand->metadata());
1424   if (operand->has_sharding()) {
1425     reshaped_operand->set_sharding(operand->sharding());
1426   }
1427   reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
1428   // Broadcast 'reshape' up to the larger size.
1429   auto broadcast = HloInstruction::CreateBroadcast(
1430       broadcast_shape, reshaped_operand, broadcast_dimensions);
1431   broadcast->set_metadata(operand->metadata());
1432   if (operand->has_sharding()) {
1433     broadcast->set_sharding(operand->sharding());
1434   }
1435   broadcast->set_frontend_attributes(operand->frontend_attributes());
1436   return broadcast;
1437 }
1438 
1439 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
1440     const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
1441     const PaddingConfig& padding_config) {
1442   return absl::make_unique<HloPadInstruction>(shape, operand, padding_value,
1443                                               padding_config);
1444 }
1445 
1446 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
1447     const Shape& shape, HloInstruction* operand, int64 inferred_dimension) {
1448   CHECK_EQ(ShapeUtil::ElementsIn(shape),
1449            ShapeUtil::ElementsIn(operand->shape()))
1450       << "shape: " << ShapeUtil::HumanString(shape)
1451       << " operand: " << ShapeUtil::HumanString(operand->shape());
1452 
1453   return absl::make_unique<HloReshapeInstruction>(shape, operand,
1454                                                   inferred_dimension);
1455 }
1456 
1457 /* static */ std::unique_ptr<HloInstruction>
1458 HloInstruction::CreateDynamicReshape(
1459     const Shape& shape, HloInstruction* data_operand,
1460     absl::Span<HloInstruction* const> dim_sizes) {
1461   CHECK_EQ(ShapeUtil::ElementsIn(shape),
1462            ShapeUtil::ElementsIn(data_operand[0].shape()))
1463       << "shape: " << ShapeUtil::HumanString(shape)
1464       << " operand: " << ShapeUtil::HumanString(data_operand[0].shape());
1465   CHECK_EQ(shape.rank(), dim_sizes.size());
1466   return absl::make_unique<HloDynamicReshapeInstruction>(shape, data_operand,
1467                                                          dim_sizes);
1468 }
1469 
1470 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
1471     const Shape& shape, HloInstruction* operand,
1472     absl::Span<const int64> dimensions) {
1473   return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
1474 }
1475 
1476 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
1477     const Shape& shape, int64 dimension,
1478     absl::Span<HloInstruction* const> operands, HloComputation* compare,
1479     bool is_stable) {
1480   return absl::make_unique<HloSortInstruction>(shape, dimension, operands,
1481                                                compare, is_stable);
1482 }
1483 
1484 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1485     const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
1486   return absl::make_unique<HloFusionInstruction>(shape, fusion_kind,
1487                                                  fused_root);
1488 }
1489 
1490 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1491     const Shape& shape, FusionKind fusion_kind,
1492     absl::Span<HloInstruction* const> operands,
1493     HloComputation* fusion_computation) {
1494   return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
1495                                                  fusion_computation);
1496 }
1497 
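// Sets a non-tuple sharding on this instruction; for a tuple-shaped
// instruction the single sharding is expanded into a tuple sharding covering
// every leaf of the tuple tree.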
1498 void HloInstruction::set_single_sharding(const HloSharding& sharding) {
1499   CHECK(!sharding.IsTuple()) << sharding;
1500   if (shape().IsTuple()) {
1501     set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
1502   } else {
1503     set_sharding(sharding);
1504   }
1505 }
1506 
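// Copies sharding (only when the two shapes are of compatible kind), metadata,
// and frontend attributes from this instruction onto the derived instruction;
// otherwise the derived instruction's sharding is cleared.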
1507 void HloInstruction::SetupDerivedInstruction(
1508     HloInstruction* derived_instruction) const {
1509   if (sharding_ != nullptr &&
1510       ShapeUtil::CompatibleKind(shape_, derived_instruction->shape())) {
1511     // Only copy sharding if the tuple tree shapes of the two instructions are
1512     // compatible, because copying it between differently shaped instructions
1513     // can produce invalid shardings.
1514     derived_instruction->set_sharding(*sharding_);
1515   } else {
1516     derived_instruction->clear_sharding();
1517   }
1518   derived_instruction->set_metadata(metadata_);
1519   derived_instruction->set_frontend_attributes(frontend_attributes_);
1520 }
1521 
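// Reports whether this instruction itself has a side effect, without looking
// into any computations it calls; HasSideEffect() below also recurses into
// called computations.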
1522 bool HloInstruction::HasSideEffectNoRecurse() const {
1523   switch (opcode_) {
1524     case HloOpcode::kSend:
1525     case HloOpcode::kSendDone:
1526     case HloOpcode::kRecv:
1527     case HloOpcode::kRecvDone:
1528     case HloOpcode::kRng:
1529     case HloOpcode::kRngGetAndUpdateState:
1530     case HloOpcode::kInfeed:
1531     case HloOpcode::kOutfeed:
1532     case HloOpcode::kTrace:
1533       return true;
1534     case HloOpcode::kAllReduce:
1535       return channel_id().has_value() ||
1536              Cast<HloAllReduceInstruction>(this)->constrain_layout();
1537     case HloOpcode::kAllToAll:
1538       return Cast<HloAllToAllInstruction>(this)->constrain_layout();
1539     case HloOpcode::kCustomCall:
1540       return Cast<HloCustomCallInstruction>(this)
1541           ->custom_call_has_side_effect();
1542     default:
1543       return false;
1544   }
1545 }
1546 
1547 bool HloInstruction::HasSideEffect() const {
1548   if (HasSideEffectNoRecurse()) {
1549     return true;
1550   }
1551   // Check if any of the called computations has a side effect.
1552   for (const auto& computation : called_computations()) {
1553     if (computation->HasSideEffect()) {
1554       return true;
1555     }
1556   }
1557   return false;
1558 }
1559 
1560 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
1561     const Shape& shape, absl::Span<HloInstruction* const> operands,
1562     HloComputation* computation) {
1563   std::unique_ptr<HloInstruction> instruction =
1564       absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
1565   for (auto operand : operands) {
1566     instruction->AppendOperand(operand);
1567   }
1568   instruction->called_computations_.push_back(computation);
1569   return instruction;
1570 }
1571 
1572 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1573     const Shape& shape, absl::Span<HloInstruction* const> operands,
1574     absl::string_view custom_call_target, string opaque) {
1575   return absl::make_unique<HloCustomCallInstruction>(
1576       shape, operands, custom_call_target, std::move(opaque));
1577 }
1578 
1579 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1580     const Shape& shape, absl::Span<HloInstruction* const> operands,
1581     HloComputation* to_apply, absl::string_view custom_call_target,
1582     string opaque) {
1583   return absl::make_unique<HloCustomCallInstruction>(
1584       shape, operands, to_apply, custom_call_target, std::move(opaque));
1585 }
1586 
1587 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1588     const Shape& shape, absl::Span<HloInstruction* const> operands,
1589     absl::Span<HloComputation* const> called_computations,
1590     absl::string_view custom_call_target, string opaque) {
1591   return absl::make_unique<HloCustomCallInstruction>(
1592       shape, operands, called_computations, custom_call_target,
1593       std::move(opaque));
1594 }
1595 
1596 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1597     const Shape& shape, absl::Span<HloInstruction* const> operands,
1598     absl::string_view custom_call_target,
1599     absl::Span<const Shape> operand_shapes_with_layout, string opaque) {
1600   return absl::make_unique<HloCustomCallInstruction>(
1601       shape, operands, custom_call_target, std::move(opaque),
1602       operand_shapes_with_layout);
1603 }
1604 
1605 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
1606     absl::Span<HloInstruction* const> elements) {
1607   std::vector<Shape> element_shapes;
1608   for (auto element : elements) {
1609     element_shapes.push_back(element->shape());
1610   }
1611   Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
1612   return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
1613 }
1614 
1615 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
1616     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
1617     const GatherDimensionNumbers& gather_dim_numbers,
1618     absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
1619   return absl::make_unique<HloGatherInstruction>(
1620       shape, operand, start_indices, gather_dim_numbers, slice_sizes,
1621       indices_are_sorted);
1622 }
1623 
1624 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
1625     const Shape& shape, HloInstruction* operand,
1626     HloInstruction* scatter_indices, HloInstruction* updates,
1627     HloComputation* update_computation,
1628     const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
1629     bool unique_indices) {
1630   return absl::make_unique<HloScatterInstruction>(
1631       shape, operand, scatter_indices, updates, update_computation,
1632       scatter_dim_numbers, indices_are_sorted, unique_indices);
1633 }
1634 
1635 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
1636     const Shape& shape, HloInstruction* operand,
1637     std::unique_ptr<DomainMetadata> operand_side_metadata,
1638     std::unique_ptr<DomainMetadata> user_side_metadata) {
1639   return absl::make_unique<HloDomainInstruction>(
1640       shape, operand, std::move(operand_side_metadata),
1641       std::move(user_side_metadata));
1642 }
1643 
1644 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
1645     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1646     HloCloneContext* context) const {
1647   VLOG(3) << "CloneWithNewOperands:\n  " << ToString();
1648   VLOG(3) << "  new operands:";
1649   for (const HloInstruction* new_operand : new_operands) {
1650     VLOG(3) << "    %" << new_operand->name();
1651   }
1652 
1653   std::unique_ptr<HloInstruction> clone;
1654   // Explicitly call the factory for the instruction type. This is more robust
1655   // in the face of code changes than copying fields explicitly. This also
1656   // properly sets the user fields of the operands.
1657   switch (opcode_) {
1658     // Ops migrated to subclasses.
1659     // TODO(b/80131774): Remove this switch when migration is complete.
1660     case HloOpcode::kBatchNormTraining:
1661     case HloOpcode::kBatchNormInference:
1662     case HloOpcode::kBatchNormGrad:
1663     case HloOpcode::kFft:
1664     case HloOpcode::kCompare:
1665     case HloOpcode::kCopyStart:
1666     case HloOpcode::kSend:
1667     case HloOpcode::kSendDone:
1668     case HloOpcode::kRecv:
1669     case HloOpcode::kRecvDone:
1670     case HloOpcode::kReverse:
1671     case HloOpcode::kConcatenate:
1672     case HloOpcode::kReduce:
1673     case HloOpcode::kTranspose:
1674     case HloOpcode::kBroadcast:
1675     case HloOpcode::kReshape:
1676     case HloOpcode::kDynamicReshape:
1677     case HloOpcode::kMap:
1678     case HloOpcode::kSlice:
1679     case HloOpcode::kConstant:
1680     case HloOpcode::kTrace:
1681     case HloOpcode::kFusion:
1682     case HloOpcode::kRng:
1683     case HloOpcode::kRngBitGenerator:
1684     case HloOpcode::kRngGetAndUpdateState:
1685     case HloOpcode::kParameter:
1686     case HloOpcode::kGetTupleElement:
1687     case HloOpcode::kReducePrecision:
1688     case HloOpcode::kAllGather:
1689     case HloOpcode::kAllReduce:
1690     case HloOpcode::kAllToAll:
1691     case HloOpcode::kCollectivePermute:
1692     case HloOpcode::kCollectivePermuteStart:
1693     case HloOpcode::kInfeed:
1694     case HloOpcode::kOutfeed:
1695     case HloOpcode::kConvolution:
1696     case HloOpcode::kCustomCall:
1697     case HloOpcode::kReduceWindow:
1698     case HloOpcode::kSelectAndScatter:
1699     case HloOpcode::kPad:
1700     case HloOpcode::kDynamicSlice:
1701     case HloOpcode::kSort:
1702     case HloOpcode::kGather:
1703     case HloOpcode::kScatter:
1704     case HloOpcode::kIota:
1705     case HloOpcode::kDot:
1706     case HloOpcode::kDomain:
1707     case HloOpcode::kGetDimensionSize:
1708     case HloOpcode::kSetDimensionSize:
1709     case HloOpcode::kTriangularSolve:
1710     case HloOpcode::kCholesky:
1711       clone = CloneWithNewOperandsImpl(shape, new_operands, context);
1712       break;
1713     // Unary ops.
1714     case HloOpcode::kAbs:
1715     case HloOpcode::kRoundNearestAfz:
1716     case HloOpcode::kBitcast:
1717     case HloOpcode::kCeil:
1718     case HloOpcode::kClz:
1719     case HloOpcode::kCollectivePermuteDone:
1720     case HloOpcode::kCopy:
1721     case HloOpcode::kCopyDone:
1722     case HloOpcode::kCos:
1723     case HloOpcode::kExp:
1724     case HloOpcode::kExpm1:
1725     case HloOpcode::kImag:
1726     case HloOpcode::kIsFinite:
1727     case HloOpcode::kFloor:
1728     case HloOpcode::kLog:
1729     case HloOpcode::kLog1p:
1730     case HloOpcode::kNot:
1731     case HloOpcode::kNegate:
1732     case HloOpcode::kPopulationCount:
1733     case HloOpcode::kReal:
1734     case HloOpcode::kRsqrt:
1735     case HloOpcode::kLogistic:
1736     case HloOpcode::kSign:
1737     case HloOpcode::kSin:
1738     case HloOpcode::kSqrt:
1739     case HloOpcode::kCbrt:
1740     case HloOpcode::kTanh:
1741       CHECK_EQ(new_operands.size(), 1);
1742       clone = CreateUnary(shape, opcode_, new_operands[0]);
1743       break;
1744     // Binary ops.
1745     case HloOpcode::kAdd:
1746     case HloOpcode::kAtan2:
1747     case HloOpcode::kComplex:
1748     case HloOpcode::kDivide:
1749     case HloOpcode::kMultiply:
1750     case HloOpcode::kSubtract:
1751     case HloOpcode::kMaximum:
1752     case HloOpcode::kMinimum:
1753     case HloOpcode::kPower:
1754     case HloOpcode::kRemainder:
1755     case HloOpcode::kAnd:
1756     case HloOpcode::kOr:
1757     case HloOpcode::kXor:
1758     case HloOpcode::kShiftLeft:
1759     case HloOpcode::kShiftRightArithmetic:
1760     case HloOpcode::kShiftRightLogical:
1761       CHECK_EQ(new_operands.size(), 2);
1762       clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
1763       break;
1764     // Ternary ops.
1765     case HloOpcode::kClamp:
1766     case HloOpcode::kSelect:
1767     case HloOpcode::kTupleSelect:
1768       CHECK_EQ(new_operands.size(), 3);
1769       clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
1770                             new_operands[2]);
1771       break;
1772     // Other supported ops.
1773     case HloOpcode::kCall:
1774       clone = CreateCall(shape, new_operands, to_apply());
1775       break;
1776     case HloOpcode::kConvert:
1777       CHECK_EQ(new_operands.size(), 1);
1778       clone = CreateConvert(shape, new_operands[0]);
1779       break;
1780     case HloOpcode::kBitcastConvert:
1781       CHECK_EQ(new_operands.size(), 1);
1782       clone = CreateBitcastConvert(shape, new_operands[0]);
1783       break;
1784     case HloOpcode::kDynamicUpdateSlice:
1785       clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
1786                                        new_operands.subspan(2));
1787       break;
1788     case HloOpcode::kTuple:
1789       clone = CreateTuple(new_operands);
1790       *clone->mutable_shape() = shape;
1791       break;
1792     case HloOpcode::kWhile:
1793       CHECK_EQ(new_operands.size(), 1);
1794       clone =
1795           CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
1796       break;
1797     case HloOpcode::kConditional:
1798       CHECK_EQ(new_operands.size(), branch_count() + 1);
1799       clone = CreateConditional(shape, new_operands[0],
1800                                 absl::MakeSpan(branch_computations()),
1801                                 new_operands.subspan(1));
1802       break;
1803     case HloOpcode::kAfterAll:
1804       if (new_operands.empty()) {
1805         clone = CreateToken();
1806       } else {
1807         clone = CreateAfterAll(new_operands);
1808       }
1809       break;
1810     case HloOpcode::kAddDependency:
1811       CHECK_EQ(new_operands.size(), 2);
1812       clone = CreateAddDependency(new_operands[0], new_operands[1]);
1813       break;
1814     case HloOpcode::kReplicaId:
1815       CHECK_EQ(new_operands.size(), 0);
1816       clone = CreateReplicaId(shape);
1817       break;
1818     case HloOpcode::kPartitionId:
1819       CHECK_EQ(new_operands.size(), 0);
1820       clone = CreatePartitionId(shape);
1821       break;
1822   }
1823   // SetupDerivedInstruction will setup the precision_config_ field.
1824   SetupDerivedInstruction(clone.get());
1825   clone->set_parent(parent_);
1826   clone->set_outer_dimension_partitions(outer_dimension_partitions_);
1827   clone->set_raw_backend_config_string(backend_config_);
1828   if (context != nullptr) {
1829     context->MapInstruction(this, clone.get());
1830     clone->ReplaceCalledComputations([&](HloComputation* callee) {
1831       return callee->parent() != context->module()
1832                  ? context->module()->DeepCloneComputation(callee, context)
1833                  : callee;
1834     });
1835   }
1836   return clone;
1837 }
1838 
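// Severs all operand and user links of this instruction so it can be deleted
// safely; the cleaned_up_ flag makes repeated calls a no-op.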
1839 void HloInstruction::DetachFromOperandsAndUsers() {
1840   if (cleaned_up_) {
1841     return;
1842   }
1843   cleaned_up_ = true;
1844   // Detach from operands. An instruction may be repeated as an operand. To
1845   // avoid calling RemoveUser twice on the same operand, check before removing.
1846   for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
1847     HloInstruction* operand = operands_[operand_num];
1848     if (operand == nullptr) {
1849       continue;
1850     }
1851     if (operand->user_map_.find(this) != operand->user_map_.end()) {
1852       operand->RemoveUser(this);
1853     }
1854     operands_[operand_num] = nullptr;
1855   }
1856 
1857   // Update users. Set each user's corresponding operand slots to `nullptr`.
1858   for (auto& user : this->users()) {
1859     for (int i = 0; i < user->operand_count(); ++i) {
1860       if (user->operands_[i] == this) {
1861         user->operands_[i] = nullptr;
1862       }
1863     }
1864   }
1865 }
1866 
1867 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewShape(
1868     const Shape& shape, const string& suffix, HloCloneContext* context) const {
1869   std::unique_ptr<HloInstruction> clone =
1870       CloneWithNewOperands(shape, operands_, context);
1871   if (suffix.empty()) {
1872     clone->name_ = name();
1873   } else {
1874     // If an instruction is cloned multiple times, avoid names like
1875     // foo.suffix.suffix.suffix. Instead of repeating the suffix, add a numeric
1876     // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the
1877     // clone of foo.suffix2 is named foo.suffix3 and so on.
1878     const string dot_suffix = "." + suffix;
1879     size_t index = name().rfind(dot_suffix);
1880     if (index == string::npos) {
1881       // Existing name does not include ".suffix".
1882       clone->name_ = name() + dot_suffix;
1883     } else {
1884       // Existing name includes ".suffix". Determine if substring after
1885       // ".suffix" is numeric and should be replaced with an incremented number.
1886       string after_suffix = name().substr(index + dot_suffix.size());
1887       if (after_suffix.empty()) {
1888         // Existing name ends in ".suffix". New name should end in ".suffix2".
1889         clone->name_ = name() + "2";
1890       } else {
1891         // If the name ends with .suffix[0-9]+, replace the trailing numeric
1892         // value with that value incremented.
1893         int64 numeric_suffix;
1894         if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
1895           clone->name_ =
1896               StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
1897         } else {
1898           // Substring after ".suffix" is non-numeric.
1899           clone->name_ = name() + dot_suffix;
1900         }
1901       }
1902     }
1903   }
1904   return clone;
1905 }
1906 
1907 std::unique_ptr<HloInstruction> HloInstruction::Clone(
1908     const string& suffix, HloCloneContext* context) const {
1909   std::unique_ptr<HloInstruction> clone =
1910       CloneWithNewShape(shape_, suffix, context);
1911   return clone;
1912 }
1913 
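// Walks up through any chain of get-tuple-element instructions and returns the
// first non-GTE ancestor together with the shape index traversed to reach it.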
1914 std::pair<const HloInstruction*, ShapeIndex>
1915 HloInstruction::LatestNonGteAncestorAndIndex() const {
1916   const HloInstruction* hlo = this;
1917   ShapeIndex index;
1918   while (hlo->opcode() == HloOpcode::kGetTupleElement) {
1919     index.push_back(hlo->tuple_index());
1920     hlo = hlo->operand(0);
1921   }
1922 
1923   // We built up index in the reverse order from what we want.
1924   std::reverse(index.begin(), index.end());
1925 
1926   return {hlo, index};
1927 }
1928 
1929 const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
1930   const HloInstruction* hlo = this;
1931   while (hlo->opcode() == HloOpcode::kGetTupleElement) {
1932     hlo = hlo->operand(0);
1933   }
1934   return hlo;
1935 }
1936 
1937 const HloInstruction* HloInstruction::operand(int64 i) const {
1938   return operands_.at(i);
1939 }
1940 
1941 HloInstruction* HloInstruction::mutable_operand(int64 i) {
1942   CHECK(operands_[i] != nullptr);
1943   return operands_.at(i);
1944 }
1945 
1946 int64 HloInstruction::operand_index(const HloInstruction* target) const {
1947   for (int64 i = 0; i < operand_count(); ++i) {
1948     if (target == operand(i)) {
1949       return i;
1950     }
1951   }
1952   LOG(FATAL) << "target was not an operand: " << target->ToString();
1953 }
1954 
1955 HloInstruction::InstructionVector HloInstruction::unique_operands() const {
1956   InstructionVector unique;
1957   absl::flat_hash_set<const HloInstruction*> seen;
1958   for (HloInstruction* operand : operands()) {
1959     if (seen.insert(operand).second) {
1960       unique.push_back(operand);
1961     }
1962   }
1963   return unique;
1964 }
1965 
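// Adds 'instruction' as a control successor of this instruction (and this one
// as its control predecessor); both must belong to the same computation, and
// an already-present edge is not duplicated.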
1966 Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
1967   TF_RET_CHECK(instruction->parent() == parent());
1968   if (!absl::c_linear_search(control_successors_, instruction)) {
1969     control_successors_.push_back(instruction);
1970     TF_RET_CHECK(
1971         !absl::c_linear_search(instruction->control_predecessors_, this));
1972     instruction->control_predecessors_.push_back(this);
1973   }
1974   return Status::OK();
1975 }
1976 
1977 Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
1978   TF_RET_CHECK(instruction->parent() == parent());
1979   TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction));
1980   TF_RETURN_IF_ERROR(
1981       EraseElementFromVector(&instruction->control_predecessors_, this));
1982   return Status::OK();
1983 }
1984 
1985 Status HloInstruction::DropAllControlDeps() {
1986   for (auto* ctrl_succ : control_successors_) {
1987     TF_RETURN_IF_ERROR(
1988         EraseElementFromVector(&ctrl_succ->control_predecessors_, this));
1989   }
1990   for (auto* ctrl_pred : control_predecessors_) {
1991     TF_RETURN_IF_ERROR(
1992         EraseElementFromVector(&ctrl_pred->control_successors_, this));
1993   }
1994   control_successors_.clear();
1995   control_predecessors_.clear();
1996   return Status::OK();
1997 }
1998 
1999 Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) {
2000   for (auto* ctrl_pred : inst->control_predecessors()) {
2001     TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(this));
2002   }
2003 
2004   for (auto* ctrl_succ : inst->control_successors()) {
2005     TF_RETURN_IF_ERROR(this->AddControlDependencyTo(ctrl_succ));
2006   }
2007 
2008   return Status::OK();
2009 }
2010 
2011 bool HloInstruction::IdenticalInternal(
2012     const HloInstruction& other,
2013     const std::function<bool(const HloInstruction*, const HloInstruction*)>&
2014         eq_operands,
2015     const std::function<bool(const HloComputation*, const HloComputation*)>&
2016         eq_computations,
2017     bool layout_sensitive, bool ignore_channel_id_values) const {
2018   // An instruction is always identical to itself.
2019   if (this == &other) {
2020     return true;
2021   }
2022 
2023   // Identical instructions must have the same opcode, shape, and identical
2024   // operands.
2025   if (opcode() != other.opcode()) {
2026     return false;
2027   }
2028   if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
2029                          : ShapeUtil::Compatible(shape(), other.shape()))) {
2030     return false;
2031   }
2032   if (operands().size() != other.operands().size()) {
2033     return false;
2034   }
2035 
2036   // Use an explicit loop rather than ContainerEquals, because copying around
2037   // std::functions may be too expensive in some cases.
2038   for (size_t i = 0; i < operands().size(); ++i) {
2039     if (!eq_operands(operand(i), other.operand(i))) {
2040       return false;
2041     }
2042   }
2043 
2044   if (backend_config_ != other.backend_config_) {
2045     return false;
2046   }
2047 
2048   if (ignore_channel_id_values) {
2049     if (auto channel_inst = DynCast<HloChannelInstruction>(this)) {
2050       return channel_inst->IdenticalSlowPathIgnoringChannelIdValues(
2051           other, eq_computations);
2052     }
2053   }
2054   return IdenticalSlowPath(other, eq_computations);
2055 }
2056 
2057 void HloInstruction::AppendOperand(HloInstruction* operand) {
2058   if (operand->parent() != nullptr) {
2059     DCHECK(!operand->parent()->IsMarkedAsDead(operand))
2060         << "Operand " << operand->name() << " is already marked dead";
2061   }
2062   operands_.push_back(operand);
2063   operand->AddUser(this);
2064 }
2065 
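// Removes the operands at the given indices (which must be in ascending
// order) by compacting operands_ in place in a single left-to-right pass and
// then shrinking the vector by the number of removals.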
2066 void HloInstruction::RemoveOperandsAtAscendingIndices(
2067     absl::Span<const int> ascending_indices) {
2068   if (ascending_indices.empty()) {
2069     return;
2070   }
2071   int next_index = 0;
2072   int removed_count = 0;
2073   for (int to_remove : ascending_indices) {
2074     while (next_index < to_remove) {
2075       operands_[next_index - removed_count] = operands_[next_index];
2076       ++next_index;
2077     }
2078     CHECK_LT(to_remove, operands_.size());
2079     ++removed_count;
2080     ++next_index;
2081   }
2082   while (next_index < operands_.size()) {
2083     operands_[next_index - removed_count] = operands_[next_index];
2084     ++next_index;
2085   }
2086   CHECK_EQ(removed_count, ascending_indices.size());
2087   operands_.resize(operands_.size() - removed_count);
2088 }
2089 
2090 void HloInstruction::AddUser(HloInstruction* user) {
2091   if (!ContainsKey(user_map_, user)) {
2092     user_map_.emplace(user, users_.size());
2093     users_.push_back(user);
2094   }
2095 }
2096 
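// Returns the dense index of 'user' within users_; user_map_ keeps this index
// in sync as users are added by AddUser and removed by RemoveUser.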
2097 int64 HloInstruction::UserId(HloInstruction* user) {
2098   auto result = user_map_.find(user);
2099   CHECK(result != user_map_.end());
2100   return result->second;
2101 }
2102 
2103 bool HloInstruction::HasConstantOperand() const {
2104   for (const HloInstruction* operand : operands_) {
2105     if (operand->IsConstant()) {
2106       return true;
2107     }
2108   }
2109   return false;
2110 }
2111 
2112 bool HloInstruction::IdenticalSlowPath(
2113     const HloInstruction& other,
2114     const std::function<bool(const HloComputation*, const HloComputation*)>&
2115         eq_computations) const {
2116   // Perform opcode specific checks.
2117   switch (opcode()) {
2118     // The result of these instructions depends only upon their opcode and
2119     // operands.
2120     case HloOpcode::kAbs:
2121     case HloOpcode::kAtan2:
2122     case HloOpcode::kAdd:
2123     case HloOpcode::kBitcast:
2124     case HloOpcode::kBitcastConvert:
2125     case HloOpcode::kCeil:
2126     case HloOpcode::kClamp:
2127     case HloOpcode::kClz:
2128     case HloOpcode::kCollectivePermuteDone:
2129     case HloOpcode::kComplex:
2130     case HloOpcode::kConvert:
2131     case HloOpcode::kCopy:
2132     case HloOpcode::kCopyStart:
2133     case HloOpcode::kCopyDone:
2134     case HloOpcode::kCos:
2135     case HloOpcode::kDivide:
2136     case HloOpcode::kDynamicUpdateSlice:
2137     case HloOpcode::kExp:
2138     case HloOpcode::kExpm1:
2139     case HloOpcode::kFloor:
2140     case HloOpcode::kImag:
2141     case HloOpcode::kIsFinite:
2142     case HloOpcode::kLog:
2143     case HloOpcode::kLog1p:
2144     case HloOpcode::kAnd:
2145     case HloOpcode::kNot:
2146     case HloOpcode::kOr:
2147     case HloOpcode::kXor:
2148     case HloOpcode::kMaximum:
2149     case HloOpcode::kMinimum:
2150     case HloOpcode::kMultiply:
2151     case HloOpcode::kNegate:
2152     case HloOpcode::kPartitionId:
2153     case HloOpcode::kPopulationCount:
2154     case HloOpcode::kPower:
2155     case HloOpcode::kReal:
2156     case HloOpcode::kRemainder:
2157     case HloOpcode::kReshape:
2158     case HloOpcode::kDynamicReshape:
2159     case HloOpcode::kReplicaId:
2160     case HloOpcode::kRoundNearestAfz:
2161     case HloOpcode::kRsqrt:
2162     case HloOpcode::kSelect:
2163     case HloOpcode::kShiftLeft:
2164     case HloOpcode::kShiftRightArithmetic:
2165     case HloOpcode::kShiftRightLogical:
2166     case HloOpcode::kLogistic:
2167     case HloOpcode::kSign:
2168     case HloOpcode::kSin:
2169     case HloOpcode::kSqrt:
2170     case HloOpcode::kCbrt:
2171     case HloOpcode::kSubtract:
2172     case HloOpcode::kTanh:
2173     case HloOpcode::kTuple:
2174     case HloOpcode::kTupleSelect:
2175       return true;
2176 
2177     // These opcodes have complex or special behavior, so just return false.
2178     case HloOpcode::kAfterAll:
2179     case HloOpcode::kAddDependency:
2180       return false;
2181 
2182     // Remaining instructions with special values.
2183     case HloOpcode::kCall:
2184       return eq_computations(to_apply(), other.to_apply());
2185     case HloOpcode::kConditional:
2186       for (int j = 0; j < branch_count(); ++j) {
2187         if (!eq_computations(branch_computation(j),
2188                              other.branch_computation(j))) {
2189           return false;
2190         }
2191       }
2192       return true;
2193     case HloOpcode::kWhile:
2194       return (eq_computations(while_body(), other.while_body()) &&
2195               eq_computations(while_condition(), other.while_condition()));
2196 
2197     // Ops migrated to subclasses should never come to this line.
2198     // TODO(b/80131774): Remove this switch when migration is complete.
2199     case HloOpcode::kBatchNormTraining:
2200     case HloOpcode::kBatchNormInference:
2201     case HloOpcode::kBatchNormGrad:
2202     case HloOpcode::kFft:
2203     case HloOpcode::kCompare:
2204     case HloOpcode::kSend:
2205     case HloOpcode::kSendDone:
2206     case HloOpcode::kRecv:
2207     case HloOpcode::kRecvDone:
2208     case HloOpcode::kReverse:
2209     case HloOpcode::kConcatenate:
2210     case HloOpcode::kReduce:
2211     case HloOpcode::kSort:
2212     case HloOpcode::kTranspose:
2213     case HloOpcode::kBroadcast:
2214     case HloOpcode::kMap:
2215     case HloOpcode::kSlice:
2216     case HloOpcode::kConstant:
2217     case HloOpcode::kIota:
2218     case HloOpcode::kTrace:
2219     case HloOpcode::kFusion:
2220     case HloOpcode::kRng:
2221     case HloOpcode::kRngBitGenerator:
2222     case HloOpcode::kRngGetAndUpdateState:
2223     case HloOpcode::kParameter:
2224     case HloOpcode::kGetTupleElement:
2225     case HloOpcode::kReducePrecision:
2226     case HloOpcode::kInfeed:
2227     case HloOpcode::kOutfeed:
2228     case HloOpcode::kAllGather:
2229     case HloOpcode::kAllReduce:
2230     case HloOpcode::kAllToAll:
2231     case HloOpcode::kCollectivePermute:
2232     case HloOpcode::kCollectivePermuteStart:
2233     case HloOpcode::kConvolution:
2234     case HloOpcode::kCustomCall:
2235     case HloOpcode::kReduceWindow:
2236     case HloOpcode::kSelectAndScatter:
2237     case HloOpcode::kPad:
2238     case HloOpcode::kDynamicSlice:
2239     case HloOpcode::kGather:
2240     case HloOpcode::kScatter:
2241     case HloOpcode::kDot:
2242     case HloOpcode::kDomain:
2243     case HloOpcode::kGetDimensionSize:
2244     case HloOpcode::kSetDimensionSize:
2245     case HloOpcode::kTriangularSolve:
2246     case HloOpcode::kCholesky:
2247       LOG(FATAL) << "Base class impl called for opcode with subclass: "
2248                  << opcode();
2249   }
2250   return false;
2251 }
2252 
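// Operand hash used by the no-argument Hash() below: only the operand's shape
// is folded in, which keeps instruction hashing shallow.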
2253 static uint64 HashOperand(const HloInstruction* hlo) {
2254   return ShapeUtil::Hash(hlo->shape());
2255 }
2256 
2257 uint64 HloInstruction::Hash(
2258     const std::function<uint64(const HloInstruction*)>& hash_operand) const {
2259   using tensorflow::Hash64Combine;
2260 
2261   uint64 hash_value = Hash64Combine(0, static_cast<uint64>(opcode()));
2262   hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(shape()));
2263 
2264   if (!IsCrossModuleAllReduce()) {
2265     if (!operands().empty()) {
2266       for (size_t i = 0; i < operands().size(); ++i) {
2267         hash_value = Hash64Combine(hash_value, hash_operand(operand(i)));
2268       }
2269     }
2270   }
2271 
2272   hash_value = Hash64Combine(hash_value, InnerHash());
2273   return hash_value;
2274 }
2275 
2276 uint64 HloInstruction::Hash() const {
2277   // Use HashOperand as an argument to prevent non-termination.
2278   return Hash(HashOperand);
2279 }
2280 
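// The base-class InnerHash() contributes only a fixed salt; subclasses may
// override it to fold in opcode-specific state.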
2281 uint64 HloInstruction::InnerHash() const { return 13; }
2282 
2283 void HloInstruction::RemoveUser(HloInstruction* user) {
2284   auto map_it = user_map_.find(user);
2285   CHECK(map_it != user_map_.end());
2286 
2287   const int64 index = map_it->second;
2288   CHECK_EQ(users_[index], user);
2289 
2290   // Move the last user into the position of the removed user.
2291   users_[index] = users_.back();
2292   user_map_[users_.back()] = index;
2293 
2294   // Remove the user from the map and drop the last slot from the vector; its
2295   // entry has just been moved into the position of the removed user.
2296   user_map_.erase(map_it);
2297   users_.pop_back();
2298 }
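
// Note (editorial, illustrative only): RemoveUser() relies on the standard
// "swap-with-last" idiom to erase from a vector in O(1) while keeping the
// index map consistent. The idiom in isolation, for a vector `v`, an index
// map `index_of`, and an index `i` to remove:
//
//   T removed = v[i];
//   v[i] = v.back();
//   index_of[v[i]] = i;
//   v.pop_back();
//   index_of.erase(removed);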
2299 
2300 Status HloInstruction::ReplaceUseWith(HloInstruction* user,
2301                                       HloInstruction* new_producer) {
2302   TF_RET_CHECK(
2303       ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
2304       << "this shape: " << ShapeUtil::HumanString(shape())
2305       << ", replacement shape: "
2306       << ShapeUtil::HumanString(new_producer->shape());
2307   return ReplaceUseWithDifferentShape(user, new_producer);
2308 }
2309 
2310 Status HloInstruction::ReplaceUseWithDifferentShape(
2311     HloInstruction* user, HloInstruction* new_producer) {
2312   VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
2313           << " with " << new_producer->name();
2314 
2315   RemoveUser(user);
2316 
2317   TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0);
2318   std::replace(user->operands_.begin(), user->operands_.end(), this,
2319                new_producer);
2320   new_producer->AddUser(user);
2321   // Custom fusions may not be able to handle deduplicated operands.
2322   if (user->opcode() == HloOpcode::kFusion) {
2323     TF_RETURN_IF_ERROR(
2324         Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
2325   }
2326   return Status::OK();
2327 }
2328 
2329 Status HloInstruction::ReplaceOperandWith(int64 operand_num,
2330                                           HloInstruction* new_operand) {
2331   auto old_operand = operand(operand_num);
2332   TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
2333                                                         new_operand->shape()))
2334       << old_operand->shape() << " is not compatible with "
2335       << new_operand->shape();
2336   return ReplaceOperandWithDifferentShape(operand_num, new_operand);
2337 }
2338 
2339 Status HloInstruction::ReplaceOperandWithDifferentShape(
2340     int64 operand_num, HloInstruction* new_operand) {
2341   TF_RET_CHECK(operand_num >= 0);
2342   TF_RET_CHECK(operand_num < operand_count());
2343   HloInstruction* old_operand = mutable_operand(operand_num);
2344   if (old_operand == new_operand) {
2345     return Status::OK();
2346   }
2347 
2348   operands_[operand_num] = new_operand;
2349 
2350   VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
2351           << new_operand->name() << ", was " << old_operand->name();
2352 
2353   if (!absl::c_linear_search(operands_, old_operand)) {
2354     old_operand->RemoveUser(this);
2355   }
2356   new_operand->AddUser(this);
2357   return Status::OK();
2358 }
2359 
2360 Status HloInstruction::ReplaceUsesWith(absl::Span<HloInstruction* const> users,
2361                                        HloInstruction* new_producer) {
2362   TF_RET_CHECK(
2363       ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
2364       << shape() << " is not compatible with " << new_producer->shape();
2365   return ReplaceAllUsesWithDifferentShape(users, new_producer);
2366 }
2367 
2368 Status HloInstruction::ReplaceAllUsesWithDifferentShape(
2369     absl::Span<HloInstruction* const> users, HloInstruction* new_producer) {
2370   for (HloInstruction* user : users) {
2371     TF_RETURN_IF_ERROR(ReplaceUseWithDifferentShape(user, new_producer));
2372   }
2373 
2374   if (parent_ && parent_->root_instruction() == this) {
2375     parent_->set_root_instruction(new_producer,
2376                                   /*accept_different_shape=*/true);
2377   }
2378   return Status::OK();
2379 }
2380 
2381 Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
2382   TF_RET_CHECK(
2383       ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
2384       << shape() << " is not compatible with " << new_producer->shape();
2385   return ReplaceAllUsesWithDifferentShape(new_producer);
2386 }
2387 
2388 Status HloInstruction::ReplaceAllUsesWithDifferentShape(
2389     HloInstruction* new_producer) {
2390   bool new_producer_is_user = false;
2391   for (HloInstruction* user : users()) {
2392     if (user == new_producer) {
2393       // It's possible that new_producer is a user of this instruction as might
2394       // be the case when replacing an instruction with a kCopy of itself. In
2395       // this case, don't do the replacement to avoid creating a cycle in the
2396       // graph. new_producer remains the only user of this instruction.
2397       new_producer_is_user = true;
2398     } else {
2399       std::replace(user->operands_.begin(), user->operands_.end(), this,
2400                    new_producer);
2401       new_producer->AddUser(user);
2402       if (user->opcode() == HloOpcode::kFusion) {
2403         TF_RETURN_IF_ERROR(
2404             Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
2405       }
2406     }
2407   }
2408   users_.clear();
2409   user_map_.clear();
2410   if (new_producer_is_user) {
2411     AddUser(new_producer);
2412   }
2413   if (parent_ && parent_->root_instruction() == this) {
2414     parent_->set_root_instruction(new_producer,
2415                                   /*accept_different_shape=*/true);
2416   }
2417 
2418   return Status::OK();
2419 }
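
// Illustrative usage sketch (editorial addition; assumes an HloComputation*
// `comp` and an HloInstruction* `instr`): a rewrite typically builds the
// replacement first and then redirects all users in one call, letting this
// method also update the computation root if needed:
//
//   HloInstruction* copy = comp->AddInstruction(
//       HloInstruction::CreateUnary(instr->shape(), HloOpcode::kCopy, instr));
//   TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(copy));
//
// Here `copy` itself uses `instr`, which is exactly the "new_producer is a
// user" case handled above: `copy` keeps `instr` as its operand and no cycle
// is created.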
2420 
2421 bool HloInstruction::IsEffectiveBitcast() const {
2422   return opcode_ == HloOpcode::kBitcast ||
2423          (opcode_ == HloOpcode::kTranspose &&
2424           ShapeUtil::TransposeIsBitcast(operand(0)->shape(), shape(),
2425                                         dimensions()));
2426 }
2427 
2428 HloComputation* HloInstruction::to_apply() const {
2429   switch (opcode_) {
2430     case HloOpcode::kCall:
2431     case HloOpcode::kMap:
2432     case HloOpcode::kReduceWindow:
2433     case HloOpcode::kReduce:
2434     case HloOpcode::kAllReduce:
2435     case HloOpcode::kScatter:
2436     case HloOpcode::kSort:
2437     case HloOpcode::kCustomCall:
2438       CHECK_EQ(called_computations_.size(), 1);
2439       return called_computations_[0];
2440     default:
2441       LOG(FATAL) << "Invalid opcode for to_apply(): "
2442                  << HloOpcodeString(opcode());
2443   }
2444 }
2445 
2446 void HloInstruction::set_to_apply(HloComputation* computation) {
2447   // Don't allow changing the computation for fused instructions so we don't
2448   // have to recompute called_instructions for the entire fusion instruction.
2449   CHECK(!IsFused());
2450   switch (opcode_) {
2451     case HloOpcode::kCall:
2452     case HloOpcode::kMap:
2453     case HloOpcode::kReduceWindow:
2454     case HloOpcode::kReduce:
2455     case HloOpcode::kAllReduce:
2456     case HloOpcode::kScatter:
2457     case HloOpcode::kSort:
2458     case HloOpcode::kCustomCall:
2459       CHECK_EQ(called_computations_.size(), 1);
2460       called_computations_[0] = computation;
2461       break;
2462     default:
2463       LOG(FATAL) << "Invalid opcode for to_apply(): "
2464                  << HloOpcodeString(opcode());
2465   }
2466 }
2467 
2468 HloComputation* HloInstruction::while_condition() const {
2469   CHECK_EQ(HloOpcode::kWhile, opcode_);
2470   return called_computations_[kConditionComputationIndex];
2471 }
2472 
2473 HloComputation* HloInstruction::while_body() const {
2474   CHECK_EQ(HloOpcode::kWhile, opcode_);
2475   return called_computations_[kBodyComputationIndex];
2476 }
2477 
2478 void HloInstruction::set_while_condition(HloComputation* computation) {
2479   // Don't allow changing the computation for fused instructions so we don't
2480   // have to recompute called_instructions for the entire fusion instruction.
2481   CHECK(!IsFused());
2482   CHECK_EQ(HloOpcode::kWhile, opcode_);
2483   called_computations_[kConditionComputationIndex] = computation;
2484 }
2485 
2486 void HloInstruction::set_while_body(HloComputation* computation) {
2487   // Don't allow changing the computation for fused instructions so we don't
2488   // have to recompute called_instructions for the entire fusion instruction.
2489   CHECK(!IsFused());
2490   CHECK_EQ(HloOpcode::kWhile, opcode_);
2491   called_computations_[kBodyComputationIndex] = computation;
2492 }
2493 
2494 HloInstruction* HloInstruction::while_init() const {
2495   CHECK_EQ(HloOpcode::kWhile, opcode_);
2496   return operands_[0];
2497 }
2498 
2499 HloComputation* HloInstruction::true_computation() const {
2500   CHECK_EQ(HloOpcode::kConditional, opcode_);
2501   CHECK_EQ(PRED, operand(0)->shape().element_type());
2502   return called_computations_[kTrueComputationIndex];
2503 }
2504 
2505 HloComputation* HloInstruction::false_computation() const {
2506   CHECK_EQ(HloOpcode::kConditional, opcode_);
2507   CHECK_EQ(PRED, operand(0)->shape().element_type());
2508   return called_computations_[kFalseComputationIndex];
2509 }
2510 
2511 const std::vector<HloComputation*>& HloInstruction::branch_computations()
2512     const {
2513   CHECK(HloOpcode::kConditional == opcode_);
2514   return called_computations_;
2515 }
2516 
2517 int HloInstruction::branch_count() const {
2518   CHECK(HloOpcode::kConditional == opcode_);
2519   return called_computations_.size();
2520 }
2521 
2522 HloComputation* HloInstruction::branch_computation(int b) const {
2523   CHECK(HloOpcode::kConditional == opcode_);
2524   CHECK_GE(b, 0);
2525   CHECK_LT(b, called_computations_.size());
2526   return called_computations_[b];
2527 }
2528 
2529 void HloInstruction::set_branch_computation(int b,
2530                                             HloComputation* computation) {
2531   // Don't allow changing the computation for fused instructions so we don't
2532   // have to recompute called_instructions for the entire fusion instruction.
2533   CHECK(!IsFused());
2534   CHECK_EQ(HloOpcode::kConditional, opcode_);
2535   called_computations_[b] = computation;
2536 }
2537 
2538 string HloInstruction::SignatureString() const {
2539   string operands =
2540       StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) {
2541         StrAppend(out, ShapeUtil::HumanString(operand->shape()));
2542       });
2543   return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
2544 }
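
// Example (editorial, illustrative): for an instruction computing
// f32[10] add(f32[10] %x, f32[10] %y), SignatureString() returns
// "(f32[10], f32[10]) -> f32[10]".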
2545 
2546 string PrintName(const string& name, bool print_ids) {
2547   if (print_ids) {
2548     return name;
2549   } else {
2550     auto dot_position = name.find_first_of('.');
2551     return name.substr(0, dot_position);
2552   }
2553 }
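
// Example (editorial, illustrative): with print_ids=false, "fusion.123" is
// printed as "fusion"; with print_ids=true the name is returned unchanged.
// Truncation happens at the first '.', so "body.1.2" would print as "body".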
2554 
2555 namespace {
2556 
2557 using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
2558 
2559 string PrintNameInternal(const string& name, const HloPrintOptions& options) {
2560   return StrCat(options.print_percent() ? "%" : "",
2561                 PrintName(name, options.print_ids()));
2562 }
2563 
2564 void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) {
2565   // This set contains HloInstructions from the top of `DFSStack` that might
2566   // belong to the cycle, i.e. if DFSStack := [back, ..., child, ..., top], then
2567   // `subgraph` := {child, ..., top}.
2568   absl::flat_hash_set<const HloInstruction*> subgraph;
2569   while (!dfs_stack->empty() && dfs_stack->back().second != child) {
2570     subgraph.insert(dfs_stack->back().second);
2571     dfs_stack->pop_back();
2572   }
2573   // Start dfs at `child` and find a cycle with all nodes in `subgraph`.
2574   absl::flat_hash_set<const HloInstruction*> visited;
2575   absl::InlinedVector<const HloInstruction*, 16> dfs;
2576   dfs.push_back(child);
2577   while (!dfs.empty()) {
2578     bool found_next_instr = false;
2579     for (const auto& user : dfs.back()->users()) {
2580       if (user == child) {
2581         dfs.push_back(child);
2582         LOG(INFO) << "\n\nDirected cycle:\n  "
2583                   << absl::StrJoin(
2584                          dfs, "\n  ",
2585                          [](std::string* out, const HloInstruction* instr) {
2586                            out->append(instr->name());
2587                          });
2588         return;
2589       }
2590       if (!subgraph.contains(user) || visited.contains(user)) {
2591         continue;
2592       }
2593       visited.insert(user);
2594       dfs.push_back(user);
2595       found_next_instr = true;
2596     }
2597     if (!found_next_instr) {
2598       dfs.pop_back();
2599     }
2600   }
2601 }
2602 
2603 }  // namespace
2604 
2605 string HloInstruction::ToString(const HloPrintOptions& options) const {
2606   CanonicalNameMap new_map;
2607   return ToStringWithCanonicalNameMap(options, &new_map);
2608 }
2609 
2610 bool HloInstruction::IsOpElementwise(HloOpcode opcode) {
2611   switch (opcode) {
2612     // Unary elementwise operations.
2613     case HloOpcode::kAbs:
2614     case HloOpcode::kRoundNearestAfz:
2615     case HloOpcode::kCeil:
2616     case HloOpcode::kClz:
2617     case HloOpcode::kConvert:
2618     case HloOpcode::kBitcastConvert:
2619     case HloOpcode::kCopy:
2620     case HloOpcode::kCos:
2621     case HloOpcode::kExp:
2622     case HloOpcode::kExpm1:
2623     case HloOpcode::kFloor:
2624     case HloOpcode::kImag:
2625     case HloOpcode::kIsFinite:
2626     case HloOpcode::kLog:
2627     case HloOpcode::kLog1p:
2628     case HloOpcode::kNot:
2629     case HloOpcode::kNegate:
2630     case HloOpcode::kPopulationCount:
2631     case HloOpcode::kReal:
2632     case HloOpcode::kReducePrecision:
2633     case HloOpcode::kRsqrt:
2634     case HloOpcode::kLogistic:
2635     case HloOpcode::kSign:
2636     case HloOpcode::kSin:
2637     case HloOpcode::kSqrt:
2638     case HloOpcode::kCbrt:
2639     case HloOpcode::kTanh:
2640       return true;
2641 
2642     // Binary elementwise operations, the same as in IsElementwiseBinary().
2643     case HloOpcode::kAdd:
2644     case HloOpcode::kAtan2:
2645     case HloOpcode::kCompare:
2646     case HloOpcode::kComplex:
2647     case HloOpcode::kDivide:
2648     case HloOpcode::kMaximum:
2649     case HloOpcode::kMinimum:
2650     case HloOpcode::kMultiply:
2651     case HloOpcode::kPower:
2652     case HloOpcode::kRemainder:
2653     case HloOpcode::kSubtract:
2654     case HloOpcode::kAnd:
2655     case HloOpcode::kOr:
2656     case HloOpcode::kXor:
2657     case HloOpcode::kShiftLeft:
2658     case HloOpcode::kShiftRightArithmetic:
2659     case HloOpcode::kShiftRightLogical:
2660       return true;
2661 
2662     // Ternary elementwise operations.
2663     case HloOpcode::kSelect:
2664     case HloOpcode::kClamp:
2665       return true;
2666 
2667     default:
2668       return false;
2669   }
2670 }
2671 
2672 bool HloInstruction::IsElementwiseImpl(
2673     const absl::optional<int64>& operand_idx) const {
2674   if (opcode_ == HloOpcode::kDynamicUpdateSlice) {
2675     return operand_idx.has_value() && operand_idx.value() == 0;
2676   }
2677   return IsOpElementwise(opcode_);
2678 }
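
// Example (editorial, illustrative): dynamic-update-slice is elementwise only
// with respect to operand 0, the array being updated; the update and
// start-index operands do not map one-to-one onto output elements. So for a
// kDynamicUpdateSlice instruction `dus`:
//
//   dus->IsElementwise();            // false
//   dus->IsElementwiseOnOperand(0);  // true
//   dus->IsElementwiseOnOperand(1);  // false (the update operand)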
2679 
2680 bool HloInstruction::IsCrossModuleAllReduce() const {
2681   return opcode() == HloOpcode::kAllReduce && channel_id();
2682 }
2683 
2684 bool HloInstruction::IsCrossReplicaAllReduce() const {
2685   return opcode() == HloOpcode::kAllReduce && !channel_id();
2686 }
2687 
2688 string HloInstruction::ToStringWithCanonicalNameMap(
2689     const HloPrintOptions& options,
2690     CanonicalNameMap* canonical_name_map) const {
2691   string result = "";
2692 
2693   // Logic to print the instruction name (e.g. "%foo = ").
2694   if (options.canonicalize_instruction_names()) {
2695     if (options.is_in_nested_computation()) {
2696       // When canonicalizing instruction names, print a (canonical) name only
2697       // inside a nested computation; a top-level ToString() call omits it.
2698       StrAppend(&result,
2699                 PrintNameInternal(canonical_name_map->LookupOrInsert(name()),
2700                                   options),
2701                 " = ");
2702     }
2703   } else {
2704     StrAppend(&result, PrintNameInternal(name(), options), " = ");
2705   }
2706 
2707   if (options.print_result_shape()) {
2708     // Print shape.
2709     if (options.include_layout_in_shapes()) {
2710       StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ");
2711     } else {
2712       StrAppend(&result, ShapeUtil::HumanString(shape()), " ");
2713     }
2714   }
2715 
2716   // Print opcode, operand(s).
2717   StrAppend(&result, HloOpcodeString(opcode()), "(",
2718             OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
2719             ")");
2720 
2721   // Print additional attributes. If an instruction contains a subcomputation,
2722   // the subcomputation is also printed here.
2723   for (const string& extra : ExtraAttributesToString(options)) {
2724     StrAppend(&result, ", ", extra);
2725   }
2726 
2727   if (options.print_metadata() &&
2728       (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
2729        !metadata_.source_file().empty())) {
2730     StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
2731   }
2732   if (options.print_backend_config() && !backend_config_.empty()) {
2733     StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\"");
2734   }
2735   return result;
2736 }
2737 
2738 string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
2739   CanonicalNameMap new_map;
2740   return OperandsToStringWithCanonicalNameMap(options, &new_map);
2741 }
2742 
2743 string HloInstruction::OperandsToStringWithCanonicalNameMap(
2744     const HloPrintOptions& options,
2745     CanonicalNameMap* canonical_name_map) const {
2746   string operands;
2747   absl::Span<HloInstruction* const> slice(operands_);
2748   const int64 kMaxOperandsToShowIfCompact = 4;
2749   if (options.compact_operands() &&
2750       slice.size() > kMaxOperandsToShowIfCompact) {
2751     slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
2752   }
2753   operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) {
2754     // If the operand has already been deleted, print `null` in its place.
2755     if (operand == nullptr) {
2756       StrAppend(out, "null ");
2757       return;
2758     }
2759     std::vector<string> str;
2760     if (options.print_operand_shape()) {
2761       if (options.include_layout_in_shapes()) {
2762         str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
2763       } else {
2764         str.push_back(ShapeUtil::HumanString(operand->shape()));
2765       }
2766     }
2767 
2768     // In a top-level HloInstruction::ToString() call, the operand name is not
2769     // part of the canonical string.
2770     if (options.canonicalize_instruction_names() &&
2771         options.is_in_nested_computation()) {
2772       str.push_back(PrintNameInternal(
2773           canonical_name_map->LookupOrInsert(operand->name()), options));
2774     } else if (options.print_operand_names()) {
2775       str.push_back(PrintNameInternal(operand->name(), options));
2776     }
2777     StrAppend(out, StrJoin(str, " "));
2778   });
2779   const int64 remaining = operands_.size() - slice.size();
2780   if (slice.size() != operands_.size()) {
2781     StrAppend(&operands, ", ...(+", remaining, ")");
2782   }
2783   return operands;
2784 }
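
// Example (editorial, illustrative): with compact_operands() set and six
// operands, only the first kMaxOperandsToShowIfCompact (4) are printed,
// followed by the count of the rest, e.g.
//
//   "f32[8] %a, f32[8] %b, f32[8] %c, f32[8] %d, ...(+2)"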
2785 
2786 std::vector<string> HloInstruction::ExtraAttributesToString(
2787     const HloPrintOptions& options) const {
2788   std::vector<string> extra = options.print_extra_attributes()
2789                                   ? ExtraAttributesToStringImpl(options)
2790                                   : std::vector<string>();
2791 
2792   if (options.print_subcomputation_mode() ==
2793       HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
2794     if (opcode() == HloOpcode::kWhile) {
2795       extra.push_back(StrCat(
2796           "condition=", PrintNameInternal(while_condition()->name(), options)));
2797       extra.push_back(
2798           StrCat("body=", PrintNameInternal(while_body()->name(), options)));
2799     } else if (opcode() == HloOpcode::kSelectAndScatter) {
2800       extra.push_back(
2801           StrCat("select=", PrintNameInternal(select()->name(), options)));
2802       extra.push_back(
2803           StrCat("scatter=", PrintNameInternal(scatter()->name(), options)));
2804     } else if (opcode() == HloOpcode::kConditional) {
2805       if (operand(0)->shape().element_type() == PRED) {
2806         extra.push_back(
2807             StrCat("true_computation=",
2808                    PrintNameInternal(true_computation()->name(), options)));
2809         extra.push_back(
2810             StrCat("false_computation=",
2811                    PrintNameInternal(false_computation()->name(), options)));
2812       } else {
2813         extra.push_back(StrCat(
2814             "branch_computations={",
2815             StrJoin(branch_computations(), ", ",
2816                     [&](string* out, const HloComputation* computation) {
2817                       StrAppend(
2818                           out, PrintNameInternal(computation->name(), options));
2819                     }),
2820             "}"));
2821       }
2822     } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
2823                opcode() == HloOpcode::kReduceWindow ||
2824                opcode() == HloOpcode::kReduce ||
2825                opcode() == HloOpcode::kAllReduce ||
2826                opcode() == HloOpcode::kScatter ||
2827                opcode() == HloOpcode::kSort) {
2828       extra.push_back(
2829           StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options)));
2830     } else if (opcode() == HloOpcode::kCustomCall) {
2831       if (!called_computations().empty()) {
2832         extra.push_back(StrCat(
2833             "called_computations={",
2834             StrJoin(called_computations(), ", ",
2835                     [&](string* out, const HloComputation* computation) {
2836                       StrAppend(
2837                           out, PrintNameInternal(computation->name(), options));
2838                     }),
2839             "}"));
2840       }
2841     } else if (!called_computations().empty()) {
2842       extra.push_back(StrCat(
2843           "calls=",
2844           StrJoin(called_computations(), ", ",
2845                   [&](string* out, const HloComputation* computation) {
2846                     StrAppend(out,
2847                               PrintNameInternal(computation->name(), options));
2848                   })));
2849     }
2850   } else if (options.print_subcomputation_mode() ==
2851              HloPrintOptions::PrintSubcomputationMode::kFullBodies) {
2852     HloPrintOptions new_options = options;
2853     new_options.set_is_in_nested_computation(true);
2854     switch (opcode()) {
2855       case HloOpcode::kWhile:
2856         extra.push_back(
2857             StrCat("condition=\n", while_condition()->ToString(new_options)));
2858         extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
2859         break;
2860       case HloOpcode::kSelectAndScatter:
2861         extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
2862         extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
2863         break;
2864       case HloOpcode::kConditional:
2865         if (operand(0)->shape().element_type() == PRED) {
2866           extra.push_back(StrCat("true_computation=\n",
2867                                  true_computation()->ToString(new_options)));
2868           extra.push_back(StrCat("false_computation=\n",
2869                                  false_computation()->ToString(new_options)));
2870         } else {
2871           extra.push_back(StrCat(
2872               "branch_computations={\n",
2873               StrJoin(branch_computations(), ",\n",
2874                       [&](string* out, const HloComputation* computation) {
2875                         StrAppend(out, computation->ToString(new_options));
2876                       }),
2877               "\n}"));
2878         }
2879         break;
2880       case HloOpcode::kCall:
2881       case HloOpcode::kMap:
2882       case HloOpcode::kReduceWindow:
2883       case HloOpcode::kReduce:
2884       case HloOpcode::kAllReduce:
2885       case HloOpcode::kScatter:
2886       case HloOpcode::kSort:
2887         extra.push_back(
2888             StrCat("to_apply=\n", to_apply()->ToString(new_options)));
2889         break;
2890       default:
2891         if (!called_computations().empty()) {
2892           extra.push_back(StrCat(
2893               "calls=\n",
2894               StrJoin(called_computations(), ", ",
2895                       [&](string* out, const HloComputation* computation) {
2896                         StrAppend(out, computation->ToString(new_options));
2897                       })));
2898         }
2899         break;
2900     }
2901   }
2902 
2903   if (has_sharding()) {
2904     extra.push_back(
2905         StrCat("sharding=", sharding().ToString(options.print_metadata())));
2906   }
2907   if (!frontend_attributes_.map().empty()) {
2908     extra.push_back(StrCat("frontend_attributes=",
2909                            FrontendAttributesToString(frontend_attributes_)));
2910   }
2911   if (!outer_dimension_partitions_.empty()) {
2912     extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
2913                                     StrJoin(outer_dimension_partitions_, ",")));
2914   }
2915 
2916   if (options.print_control_dependencies() && !control_predecessors_.empty()) {
2917     extra.push_back(StrCat("control-predecessors={",
2918                            StrJoin(control_predecessors_, ", ",
2919                                    [&](string* out, HloInstruction* pre) {
2920                                      StrAppend(out, PrintNameInternal(
2921                                                         pre->name(), options));
2922                                    }),
2923                            "}"));
2924   }
2925 
2926   return extra;
2927 }
2928 
2929 string HloInstruction::ToShortString() const {
2930   return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
2931                 StrJoin(operands_, ", ",
2932                         [](string* out, HloInstruction* operand) {
2933                           StrAppend(out, "%", operand->name());
2934                         }),
2935                 ")");
2936 }
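
// Example (editorial, illustrative): ToShortString() omits shapes and
// attributes, e.g. "%add.3 = add(%x, %y)".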
2937 
2938 HloInstructionProto HloInstruction::ToProto() const {
2939   HloInstructionProto proto;
2940   CHECK(unique_id_ != -1)
2941       << "This instruction does not have a valid id. Please make sure the "
2942          "instruction is inside a module before dumping it.";
2943   proto.set_id(unique_id_);
2944   proto.set_name(name_);
2945   proto.set_opcode(HloOpcodeString(opcode_));
2946   *proto.mutable_shape() = shape_.ToProto();
2947   for (const HloInstruction* operand : operands_) {
2948     proto.add_operand_ids(operand->unique_id());
2949   }
2950   for (const HloInstruction* control : control_predecessors_) {
2951     proto.add_control_predecessor_ids(control->unique_id());
2952   }
2953 
2954   *proto.mutable_metadata() = metadata_;
2955   proto.set_backend_config(backend_config_);
2956   if (opcode() != HloOpcode::kFusion) {
2957     for (const HloComputation* computation : called_computations_) {
2958       proto.add_called_computation_ids(computation->unique_id());
2959     }
2960   }
2961 
2962   if (has_sharding()) {
2963     *proto.mutable_sharding() = sharding().ToProto();
2964   }
2965   if (!outer_dimension_partitions_.empty()) {
2966     for (const auto& idx : outer_dimension_partitions_) {
2967       proto.mutable_outer_dimension_partitions()->Add(idx);
2968     }
2969   }
2970 
2971   *proto.mutable_frontend_attributes() = frontend_attributes_;
2972 
2973   return proto;
2974 }
2975 
2976 string HloInstruction::ToCategory() const {
2977   if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
2978       opcode() == HloOpcode::kReshape ||
2979       opcode() == HloOpcode::kDynamicReshape) {
2980     return "data formatting";
2981   }
2982 
2983   if (IsElementwise()) {
2984     return "non-fusion elementwise";
2985   }
2986 
2987   return HloOpcodeString(opcode());
2988 }
2989 
2990 HloInstruction* HloInstruction::tracing() const { return trace_instruction_; }
2991 
2992 void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
2993   trace_instruction_ = trace_instruction;
2994 }
2995 
2996 bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
2997 
2998 bool HloInstruction::IsCustomCall(absl::string_view target) const {
2999   return opcode() == HloOpcode::kCustomCall && custom_call_target() == target;
3000 }
3001 
3002 bool HloInstruction::IsInputFusion() const {
3003   return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kInput;
3004 }
3005 
3006 bool HloInstruction::IsLoopFusion() const {
3007   return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kLoop;
3008 }
3009 
3010 bool HloInstruction::IsOutputFusion() const {
3011   return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kOutput;
3012 }
3013 
3014 bool HloInstruction::IsCustomFusion() const {
3015   return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kCustom;
3016 }
3017 
3018 bool HloInstruction::IsFusible() const {
3019   // Instructions which are traced should not be fused.
3020   if (tracing()) {
3021     return false;
3022   }
3023   // Some kinds of instructions don't make sense to fuse.
3024   switch (opcode_) {
3025     case HloOpcode::kDomain:
3026     case HloOpcode::kParameter:
3027     case HloOpcode::kWhile:
3028     case HloOpcode::kConditional:
3029     case HloOpcode::kCall:
3030       return false;
3031     // Fusions are always fusible.
3032     case HloOpcode::kFusion:
3033     // Side effecting reduce and reduce window would be invalid HLO.
3034     case HloOpcode::kMap:
3035     case HloOpcode::kReduce:
3036     case HloOpcode::kReduceWindow:
3037       return true;
3038     case HloOpcode::kRng:
3039       return user_count() <= 1;
3040     // Side effecting instructions cannot be fused.
3041     default:
3042       return !HasSideEffect();
3043   }
3044 }
3045 
3046 HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
3047     : unique_id_(-1),
3048       opcode_(opcode),
3049       shape_(shape),
3050       name_(HloOpcodeString(opcode)),
3051       marked_as_dead_(false) {
3052   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
3053 }
3054 
3055 template <typename HloInstructionPtr>
3056 Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
3057   switch (opcode_) {
3058     case HloOpcode::kAbs:
3059       return visitor->HandleAbs(this);
3060     case HloOpcode::kAtan2:
3061       return visitor->HandleAtan2(this);
3062     case HloOpcode::kRoundNearestAfz:
3063       return visitor->HandleRound(this);
3064     case HloOpcode::kBatchNormTraining:
3065       return visitor->HandleBatchNormTraining(this);
3066     case HloOpcode::kBatchNormInference:
3067       return visitor->HandleBatchNormInference(this);
3068     case HloOpcode::kBatchNormGrad:
3069       return visitor->HandleBatchNormGrad(this);
3070     case HloOpcode::kLogistic:
3071       return visitor->HandleLogistic(this);
3072     case HloOpcode::kSign:
3073       return visitor->HandleSign(this);
3074     case HloOpcode::kConstant:
3075       return visitor->HandleConstant(this);
3076     case HloOpcode::kGetTupleElement:
3077       return visitor->HandleGetTupleElement(this);
3078     case HloOpcode::kParameter:
3079       return visitor->HandleParameter(this);
3080     case HloOpcode::kCompare:
3081       return visitor->HandleCompare(this);
3082     case HloOpcode::kComplex:
3083       return visitor->HandleComplex(this);
3084     case HloOpcode::kAdd:
3085       return visitor->HandleAdd(this);
3086     case HloOpcode::kDivide:
3087       return visitor->HandleDivide(this);
3088     case HloOpcode::kSubtract:
3089       return visitor->HandleSubtract(this);
3090     case HloOpcode::kMaximum:
3091       return visitor->HandleMaximum(this);
3092     case HloOpcode::kMinimum:
3093       return visitor->HandleMinimum(this);
3094     case HloOpcode::kAnd:
3095       return visitor->HandleAnd(this);
3096     case HloOpcode::kOr:
3097       return visitor->HandleOr(this);
3098     case HloOpcode::kXor:
3099       return visitor->HandleXor(this);
3100     case HloOpcode::kShiftLeft:
3101       return visitor->HandleShiftLeft(this);
3102     case HloOpcode::kShiftRightArithmetic:
3103       return visitor->HandleShiftRightArithmetic(this);
3104     case HloOpcode::kShiftRightLogical:
3105       return visitor->HandleShiftRightLogical(this);
3106     case HloOpcode::kConcatenate:
3107       return visitor->HandleConcatenate(this);
3108     case HloOpcode::kConvert:
3109       return visitor->HandleConvert(this);
3110     case HloOpcode::kBitcastConvert:
3111       return visitor->HandleBitcastConvert(this);
3112     case HloOpcode::kCopy:
3113       return visitor->HandleCopy(this);
3114     case HloOpcode::kMultiply:
3115       return visitor->HandleMultiply(this);
3116     case HloOpcode::kDot:
3117       return visitor->HandleDot(this);
3118     case HloOpcode::kPower:
3119       return visitor->HandlePower(this);
3120     case HloOpcode::kRemainder:
3121       return visitor->HandleRemainder(this);
3122     case HloOpcode::kSelect:
3123       return visitor->HandleSelect(this);
3124     case HloOpcode::kTupleSelect:
3125       return visitor->HandleTupleSelect(this);
3126     case HloOpcode::kConvolution:
3127       return visitor->HandleConvolution(this);
3128     case HloOpcode::kFft:
3129       return visitor->HandleFft(this);
3130     case HloOpcode::kAllGather:
3131       return visitor->HandleAllGather(this);
3132     case HloOpcode::kAllReduce:
3133       return visitor->HandleAllReduce(this);
3134     case HloOpcode::kAllToAll:
3135       return visitor->HandleAllToAll(this);
3136     case HloOpcode::kCollectivePermute:
3137       return visitor->HandleCollectivePermute(this);
3138     case HloOpcode::kCollectivePermuteStart:
3139       return visitor->HandleCollectivePermuteStart(this);
3140     case HloOpcode::kCollectivePermuteDone:
3141       return visitor->HandleCollectivePermuteDone(this);
3142     case HloOpcode::kReplicaId:
3143       return visitor->HandleReplicaId(this);
3144     case HloOpcode::kPartitionId:
3145       return visitor->HandlePartitionId(this);
3146     case HloOpcode::kTuple:
3147       return visitor->HandleTuple(this);
3148     case HloOpcode::kMap:
3149       return visitor->HandleMap(this);
3150     case HloOpcode::kClamp:
3151       return visitor->HandleClamp(this);
3152     case HloOpcode::kReduce:
3153       return visitor->HandleReduce(this);
3154     case HloOpcode::kReduceWindow:
3155       return visitor->HandleReduceWindow(this);
3156     case HloOpcode::kSelectAndScatter:
3157       return visitor->HandleSelectAndScatter(this);
3158     case HloOpcode::kNegate:
3159       return visitor->HandleNegate(this);
3160     case HloOpcode::kExp:
3161       return visitor->HandleExp(this);
3162     case HloOpcode::kExpm1:
3163       return visitor->HandleExpm1(this);
3164     case HloOpcode::kFloor:
3165       return visitor->HandleFloor(this);
3166     case HloOpcode::kCeil:
3167       return visitor->HandleCeil(this);
3168     case HloOpcode::kClz:
3169       return visitor->HandleClz(this);
3170     case HloOpcode::kLog:
3171       return visitor->HandleLog(this);
3172     case HloOpcode::kLog1p:
3173       return visitor->HandleLog1p(this);
3174     case HloOpcode::kTanh:
3175       return visitor->HandleTanh(this);
3176     case HloOpcode::kCos:
3177       return visitor->HandleCos(this);
3178     case HloOpcode::kSin:
3179       return visitor->HandleSin(this);
3180     case HloOpcode::kSqrt:
3181       return visitor->HandleSqrt(this);
3182     case HloOpcode::kCbrt:
3183       return visitor->HandleCbrt(this);
3184     case HloOpcode::kRsqrt:
3185       return visitor->HandleRsqrt(this);
3186     case HloOpcode::kReal:
3187       return visitor->HandleReal(this);
3188     case HloOpcode::kImag:
3189       return visitor->HandleImag(this);
3190     case HloOpcode::kIsFinite:
3191       return visitor->HandleIsFinite(this);
3192     case HloOpcode::kNot:
3193       return visitor->HandleNot(this);
3194     case HloOpcode::kPopulationCount:
3195       return visitor->HandlePopulationCount(this);
3196     case HloOpcode::kBitcast:
3197       return visitor->HandleBitcast(this);
3198     case HloOpcode::kBroadcast:
3199       return visitor->HandleBroadcast(this);
3200     case HloOpcode::kPad:
3201       return visitor->HandlePad(this);
3202     case HloOpcode::kReshape:
3203       return visitor->HandleReshape(this);
3204     case HloOpcode::kDynamicReshape:
3205       return visitor->HandleDynamicReshape(this);
3206     case HloOpcode::kTranspose:
3207       return visitor->HandleTranspose(this);
3208     case HloOpcode::kReverse:
3209       return visitor->HandleReverse(this);
3210     case HloOpcode::kReducePrecision:
3211       return visitor->HandleReducePrecision(this);
3212     case HloOpcode::kSlice:
3213       return visitor->HandleSlice(this);
3214     case HloOpcode::kDynamicSlice:
3215       return visitor->HandleDynamicSlice(this);
3216     case HloOpcode::kDynamicUpdateSlice:
3217       return visitor->HandleDynamicUpdateSlice(this);
3218     case HloOpcode::kSort:
3219       return visitor->HandleSort(this);
3220     case HloOpcode::kInfeed:
3221       return visitor->HandleInfeed(this);
3222     case HloOpcode::kOutfeed:
3223       return visitor->HandleOutfeed(this);
3224     case HloOpcode::kRng:
3225       return visitor->HandleRng(this);
3226     case HloOpcode::kRngBitGenerator:
3227       return visitor->HandleRngBitGenerator(this);
3228     case HloOpcode::kRngGetAndUpdateState:
3229       return visitor->HandleRngGetAndUpdateState(this);
3230     case HloOpcode::kWhile:
3231       return visitor->HandleWhile(this);
3232     case HloOpcode::kFusion:
3233       return visitor->HandleFusion(this);
3234     case HloOpcode::kCall:
3235       return visitor->HandleCall(this);
3236     case HloOpcode::kConditional:
3237       return visitor->HandleConditional(this);
3238     case HloOpcode::kCustomCall:
3239       return visitor->HandleCustomCall(this);
3240     case HloOpcode::kCopyStart:
3241       return visitor->HandleCopyStart(this);
3242     case HloOpcode::kCopyDone:
3243       return visitor->HandleCopyDone(this);
3244     case HloOpcode::kRecv:
3245       return visitor->HandleRecv(this);
3246     case HloOpcode::kRecvDone:
3247       return visitor->HandleRecvDone(this);
3248     case HloOpcode::kSend:
3249       return visitor->HandleSend(this);
3250     case HloOpcode::kSendDone:
3251       return visitor->HandleSendDone(this);
3252     case HloOpcode::kGather:
3253       return visitor->HandleGather(this);
3254     case HloOpcode::kScatter:
3255       return visitor->HandleScatter(this);
3256     case HloOpcode::kDomain:
3257       return visitor->HandleDomain(this);
3258     case HloOpcode::kAfterAll:
3259       return visitor->HandleAfterAll(this);
3260     case HloOpcode::kAddDependency:
3261       return visitor->HandleAddDependency(this);
3262     case HloOpcode::kIota:
3263       return visitor->HandleIota(this);
3264     case HloOpcode::kGetDimensionSize:
3265       return visitor->HandleGetDimensionSize(this);
3266     case HloOpcode::kSetDimensionSize:
3267       return visitor->HandleSetDimensionSize(this);
3268     case HloOpcode::kTriangularSolve:
3269       return visitor->HandleTriangularSolve(this);
3270     case HloOpcode::kCholesky:
3271       return visitor->HandleCholesky(this);
3272 
3273     // These opcodes are not handled here.
3274     case HloOpcode::kTrace:
3275       return Status::OK();
3276   }
3277   return InternalError(
3278       "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
3279       "please file a bug for XLA.",
3280       HloOpcodeString(opcode_));
3281 }
3282 
3283 // Explicit instantiations.
3284 template Status HloInstruction::Visit(DfsHloVisitor* visitor);
3285 template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
3286 
3287 // Push "child" onto the dfs_stack if not already visited.  Returns false if a
3288 // cycle was detected, and true otherwise.
3289 template <typename Visitor>
3290 inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
3291                          HloInstruction* child) {
3292   CHECK(child != nullptr);
3293   const int id = child->unique_id();
3294   CHECK_GE(id, 0) << "instruction may not have a parent computation";
3295   switch (visitor->GetVisitState(id)) {
3296     case Visitor::kVisiting:
3297       return false;
3298 
3299     case Visitor::kVisited:
3300       // Nothing to do
3301       return true;
3302 
3303     case Visitor::kNotVisited:
3304       dfs_stack->push_back(std::make_pair(id, child));
3305       return true;
3306   }
3307 }
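
// Note (editorial): the tri-state visit marking above implements the standard
// iterative cycle check: a child already in state kVisiting is an ancestor on
// the current DFS path, so reaching it again means the graph has a back edge,
// and PushDFSChild() signals that by returning false.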
3308 
3309 using InternalCompareFunction =
3310     std::function<bool(std::pair<int, const HloInstruction*>,
3311                        std::pair<int, const HloInstruction*>)>;
3312 template <typename Visitor>
3313 static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
3314                            const InternalCompareFunction* operand_order,
3315                            bool ignore_control_predecessors) {
3316   // Calculating the instruction count within a module can be expensive on
3317   // large models, so only do it if the visit state is empty. This helps when
3318   // the same visitor is reused across many computations of a single module.
3319   if (visitor->VisitStateCapacity() == 0) {
3320     visitor->ReserveVisitStates(root->GetModule()->instruction_count());
3321   }
3322 
3323   // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
3324   //
3325   // We need to keep track of both the id and the instruction because
3326   // instructions can get deleted while they are on the stack, so we
3327   // can't always use the (potentially dead) instruction object to grab
3328   // its id.
3329   DFSStack dfs_stack;
3330   dfs_stack.emplace_back(root->unique_id(), root);
3331 
3332   do {
3333     DCHECK(!dfs_stack.empty());
3334 
3335     int current_id = dfs_stack.back().first;
3336     HloInstruction* current_node = dfs_stack.back().second;
3337     CHECK_GE(current_id, 0) << current_id << ": " << current_node
3338                             << ": instruction may not have parent computation";
3339     typename Visitor::VisitState visit_state =
3340         visitor->GetVisitState(current_id);
3341     if (visit_state == Visitor::kVisited) {
3342       dfs_stack.pop_back();
3343       VLOG(3) << "Not visiting HLO (id = " << current_id
3344               << ") as it was already visited.";
3345       continue;
3346     }
3347 
3348     if (visit_state == Visitor::kVisiting) {
3349       dfs_stack.pop_back();
3350 
3351       TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
3352       VLOG(2) << "Visiting HLO %" << current_node->name();
3353       TF_RETURN_IF_ERROR(current_node->Visit(visitor));
3354       visitor->SetVisitState(current_id, Visitor::kVisited);
3355       TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
3356       continue;
3357     }
3358 
3359     visitor->SetVisitState(current_id, Visitor::kVisiting);
3360 
3361     const size_t old_dfs_stack_size = dfs_stack.size();
3362     for (HloInstruction* child : current_node->operands()) {
3363       if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
3364         PrintCycle(child, &dfs_stack);
3365         return FailedPrecondition(
3366             "A cycle is detected while visiting instruction %s",
3367             current_node->ToString());
3368       }
3369     }
3370 
3371     if (!ignore_control_predecessors) {
3372       for (HloInstruction* child : current_node->control_predecessors()) {
3373         if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
3374           PrintCycle(child, &dfs_stack);
3375           return FailedPrecondition(
3376               "A cycle is detected while visiting instruction %s",
3377               current_node->ToString());
3378         }
3379       }
3380     }
3381 
3382     if (operand_order != nullptr) {
3383       std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
3384                 *operand_order);
3385     }
3386 
3387     // This makes the traversal order the same as what you'd expect
3388     // out of a recursive algorithm.
3389     std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
3390   } while (!dfs_stack.empty());
3391 
3392   return Status::OK();
3393 }
3394 
3395 template <typename HloInstructionPtr>
3396 Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
3397                               bool call_finish_visit,
3398                               bool ignore_control_predecessors) {
3399   VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
3400   TF_RETURN_IF_ERROR(
3401       PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
3402   if (call_finish_visit) {
3403     TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
3404   }
3405   return Status::OK();
3406 }
3407 
3408 // Explicit instantiations.
3409 template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
3410 template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
3411 
3412 Status HloInstruction::AcceptWithOperandOrder(
3413     DfsHloVisitor* visitor, const CompareFunction& operand_order,
3414     bool call_finish_visit) {
3415   VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
3416   InternalCompareFunction func = [&operand_order](
3417                                      std::pair<int, const HloInstruction*> a,
3418                                      std::pair<int, const HloInstruction*> b) {
3419     // Call the client's comparison function on the actual HloInstruction*
3420     // objects (ignoring the internal ids we also have in our stack entries)
3421     return operand_order(a.second, b.second);
3422   };
3423   TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
3424                                   /*ignore_control_predecessors=*/false));
3425   if (call_finish_visit) {
3426     VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
3427     TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
3428     VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
3429   }
3430   VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
3431   return Status::OK();
3432 }
3433 
3434 const Shape& HloInstruction::shape() const { return shape_; }
3435 
3436 absl::InlinedVector<int64, 4> HloInstruction::OperandIndices(
3437     const HloInstruction* operand) const {
3438   absl::InlinedVector<int64, 4> result;
3439   for (int64 i = 0; i < operand_count(); ++i) {
3440     if (this->operand(i) == operand) {
3441       result.push_back(i);
3442     }
3443   }
3444   return result;
3445 }
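
// Example (editorial, illustrative): for an instruction computing add(%p, %p),
// OperandIndices(p) returns {0, 1}, since the same operand may appear in
// several operand positions.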
3446 
3447 bool HloInstruction::IsElementwiseBinary() const {
3448   return IsElementwise() && operand_count() == 2;
3449 }
3450 
3451 bool HloInstruction::IsElementwise() const {
3452   return IsElementwiseImpl(absl::nullopt);
3453 }
3454 
3455 bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
3456   return IsElementwiseImpl(operand_idx);
3457 }
3458 
3459 // A helper class for memoized, recursive computation of HloOpcode::kFusion
3460 // in HloInstruction::OperandElementUse below.
3461 class HloInstruction::FusionReusesParamElements {
3462  public:
3463   using UseKind = HloInstruction::UseKind;
3464 
3465   // We could instead iterate backwards through fused_instructions_ here, as it
3466   // is in reverse postorder, and compute whether each fused instruction reuses
3467   // the value of this parameter. That would save stack space but would not let
3468   // us finish early when we find a reuse.
3469   static UseKind Compute(int64 i, const HloInstruction& hlo) {
3470     absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
3471     return ComputeInternal(i, hlo, &memoization_cache);
3472   }
3473 
3474  private:
3475   static UseKind ComputeInternal(
3476       int64 i, const HloInstruction& hlo,
3477       absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
3478     if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
3479       if (hlo_param->parameter_number() == i) {
3480         return UseKind::kUse;
3481       }
3482     }
3483 
3484     auto p = cache->emplace(&hlo, UseKind::kNoUse);
3485     auto value_it = p.first;
3486     const bool key_is_new = p.second;
3487 
3488     if (key_is_new) {
3489       for (int64 j = 0; j < hlo.operands_.size(); ++j) {
3490         UseKind old_val = value_it->second;
3491 
3492         // The next operation invalidates iterators.
3493         UseKind new_val =
3494             Fold(old_val,
3495                  FoldUseMandatory(hlo.OperandElementUse(j),
3496                                   ComputeInternal(i, *hlo.operand(j), cache)));
3497 
3498         // Re-acquire the iterator. We could work harder to do this only if
3499         // absolutely necessary, but this code is not hot enough to warrant
3500         // that.
3501         value_it = cache->find(&hlo);
3502         value_it->second = new_val;
3503         // Fold() minimizes the UseKind value. If it is already minimum, we can
3504         // break the loop early.
3505         if (new_val == UseKind::kReuse) {
3506           break;
3507         }
3508       }
3509     }
3510     return value_it->second;
3511   }
3512 
3513   // Combines two UseKinds.
3514   //
3515   // This is the min operation on the lattice
3516   //
3517   //   kReuse < kUse < kNoUse.
3518   //
3519   // Two kUse values which have different permutations count as kReuse.
3520   static UseKind Fold(UseKind a, UseKind b) {
3521     // Without loss of generality, let `b` be the value with the larger use
3522     // kind.
3523     if (b.kind < a.kind) {
3524       std::swap(a, b);
3525     }
3526     // If the kinds are different, return the smaller one, namely `a`.
3527     if (a.kind != b.kind) {
3528       return a;
3529     }
3530     // If the kinds are both kUse, check that they're the same permutation.
3531     if (a.kind == UseKind::kUse && b.kind == UseKind::kUse &&
3532         a.permutation_instr != b.permutation_instr) {
3533       return UseKind::kReuse;
3534     }
3535     return a;  // They're the same.
3536   }
3537 
3538   // Combines two UseKinds differently than Fold().
3539   //
3540   // This is the min operation on the lattice
3541   //
3542   //   kNoUse < kReuse < kUse.
3543   //
3544   // If `a` and `b` are both kUse and one has a non-null permutation
3545   // instruction, returns kUse with that permutation. On the other hand, if
3546   // both have different, non-null permutation instructions, returns kReuse.
3547   //
3548   // You can think of this sort of as a conjunction, whereas Fold is sort of a
3549   // disjunction.  FoldUseMandatory() says "no use" if either input isn't used,
3550   // whereas Fold() would say "use".
3551   static UseKind FoldUseMandatory(UseKind a, UseKind b) {
3552     if (a.kind == UseKind::kNoUse || b.kind == UseKind::kNoUse) {
3553       return UseKind::kNoUse;
3554     }
3555     if (a.kind == UseKind::kReuse || b.kind == UseKind::kReuse) {
3556       return UseKind::kReuse;
3557     }
3558     if (a.permutation_instr == b.permutation_instr) {
3559       return a;  // They're the same.
3560     }
3561     if (b.permutation_instr == nullptr) {
3562       return a;
3563     }
3564     if (a.permutation_instr == nullptr) {
3565       return b;
3566     }
3567     return UseKind::kReuse;
3568   }
3569 };
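
// Worked example (editorial, illustrative only): suppose a fused parameter
// feeds both a transpose and a reverse, each consumed once. Each path alone
// yields kUse tagged with that path's permuting instruction. Combining the two
// paths with Fold() gives kReuse, because the two kUse values carry different
// permutations, i.e. the parameter's elements are effectively read more than
// once (in different orders). FoldUseMandatory(), by contrast, is applied
// along a single operand chain, so a kNoUse anywhere in the chain makes the
// whole chain kNoUse.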
3570 
3571 HloInstruction::UseKind HloInstruction::OperandElementUse(
3572     int64 operand_num) const {
3573   switch (opcode_) {
3574     case HloOpcode::kBitcast:
3575       // A bitcast that only adds or removes degenerate (i.e. size 1) dimensions
3576       // doesn't permute its elements, so it counts as a plain, non-permuting
3577       // use.
3578       return ShapeUtil::DropDegenerateDimensions(shape()) ==
3579                      ShapeUtil::DropDegenerateDimensions(operand(0)->shape())
3580                  ? UseKind::kUse
3581                  : UseKind::Permuting(this);
3582     case HloOpcode::kConcatenate:
3583     case HloOpcode::kReshape:
3584     case HloOpcode::kReverse:
3585     case HloOpcode::kSlice:
3586     case HloOpcode::kTranspose:
3587       return UseKind::Permuting(this);
3588     case HloOpcode::kPad:
3589       // Pad reuses the padding value but not the padded array elements.
3590       return operand_num > 0 ? UseKind::kReuse : UseKind::Permuting(this);
3591     case HloOpcode::kReduce:
3592       // Reduce reuses the init values but not the operand array elements.
3593       return operand_num >= Cast<HloReduceInstruction>(this)->input_count()
3594                  ? UseKind::kReuse
3595                  : UseKind::Permuting(this);
3596     case HloOpcode::kFusion:
3597       // Uses the memoizing, recursive computation defined above.
3598       return FusionReusesParamElements::Compute(operand_num,
3599                                                 *fused_expression_root());
3600     case HloOpcode::kDot:
3601       // Matrix-vector dots do not reuse the matrix operand.
3602       if (shape().dimensions_size() <= 1) {
3603         if ((operand_num == 0 && operand(1)->shape().rank() <= 1) ||
3604             (operand_num == 1 && operand(0)->shape().rank() <= 1)) {
3605           return UseKind::kUse;
3606         }
3607       }
3608       return UseKind::kReuse;
3609     case HloOpcode::kDynamicUpdateSlice:
3610       // Dynamic-update-slice reuses only start_indices.
3611       if (operand_num == 0 || operand_num == 1) {
3612         return UseKind::kUse;
3613       }
3614       return UseKind::kReuse;
3615     case HloOpcode::kGather:
3616       // Gather reads its indices in a linear fashion, and it permutes the
3617       // vector it's gathering from.
3618       return operand_num == 0 ? UseKind::kUse : UseKind::Permuting(this);
3619     default:
3620       return IsElementwise() ? UseKind::kUse : UseKind::kReuse;
3621   }
3622 }
3623 
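// Reports whether a kReshape only inserts or deletes 1-sized (degenerate)
// dimensions, and if so which dimensions those are (delegating to
// ShapeUtil::InsertedOrDeleted1SizedDimensions); e.g. reshaping f32[4,1,8] to
// f32[4,8] merely deletes a degenerate dimension.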
std::tuple<bool, std::vector<int64>, std::vector<int64>>
HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
  if (HloOpcode::kReshape != opcode_) {
    return std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
  }
  return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
                                                      shape_);
}

string ToString(HloInstruction::FusionKind kind) {
  switch (kind) {
    case HloInstruction::FusionKind::kLoop:
      return "kLoop";
    case HloInstruction::FusionKind::kInput:
      return "kInput";
    case HloInstruction::FusionKind::kOutput:
      return "kOutput";
    case HloInstruction::FusionKind::kCustom:
      return "kCustom";
  }
}

StatusOr<HloInstruction::FusionKind> StringToFusionKind(
    const string& kind_name) {
  if (kind_name == "kLoop") {
    return HloInstruction::FusionKind::kLoop;
  }
  if (kind_name == "kInput") {
    return HloInstruction::FusionKind::kInput;
  }
  if (kind_name == "kOutput") {
    return HloInstruction::FusionKind::kOutput;
  }
  if (kind_name == "kCustom") {
    return HloInstruction::FusionKind::kCustom;
  }
  return InvalidArgument("Unknown fusion kind: %s", kind_name);
}

string FrontendAttributesToString(
    const FrontendAttributes& frontend_attributes) {
  std::vector<std::pair<string, string>> sorted_attributes(
      frontend_attributes.map().begin(), frontend_attributes.map().end());
  absl::c_sort(sorted_attributes);
  // The frontend attributes are printed as a comma-separated list of
  // attribute="value" pairs, e.g.
  // frontend_attributes={name="value_a",type="int32"}.
  const auto formatter = [](string* out,
                            const std::pair<string, string>& item) {
    absl::StrAppend(out, item.first, "=\"", item.second, "\"");
  };
  return absl::StrFormat("{%s}",
                         absl::StrJoin(sorted_attributes, ",", formatter));
}

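// Renders the padding config as one "low_high" (or "low_high_interior", when
// any dimension has interior padding) group per dimension, joined with "x";
// e.g. edge padding of (1,2) and (0,0) with no interior padding prints as
// "1_2x0_0".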
string PaddingConfigToString(const PaddingConfig& padding) {
  bool has_interior_padding =
      absl::c_any_of(padding.dimensions(),
                     [](const PaddingConfig::PaddingConfigDimension& dim) {
                       return dim.interior_padding() != 0;
                     });
  return StrJoin(
      padding.dimensions(), "x",
      [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
        StrAppend(
            out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
            has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
      });
}

string RandomDistributionToString(const RandomDistribution& distribution) {
  return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
}

string RandomAlgorithmToString(const RandomAlgorithm& algorithm) {
  return absl::AsciiStrToLower(RandomAlgorithm_Name(algorithm));
}

string PrecisionToString(const PrecisionConfig::Precision& precision) {
  return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
}

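// For example, an NHWC convolution with an HWIO kernel and an NHWC output is
// rendered as "b01f_01io->b01f".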
string ConvolutionDimensionNumbersToString(
    const ConvolutionDimensionNumbers& dnums) {
  // lhs_dims[i] is the symbol of the logical dimension i for the lhs
  // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
  std::vector<string> lhs_dims(2 + dnums.input_spatial_dimensions().size());
  lhs_dims[dnums.input_batch_dimension()] = 'b';
  lhs_dims[dnums.input_feature_dimension()] = 'f';
  for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
    lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
  }

  std::vector<string> rhs_dims(2 + dnums.kernel_spatial_dimensions().size());
  rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
  rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
  for (int64 i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
    rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
  }

  std::vector<string> output_dims(2 + dnums.output_spatial_dimensions().size());
  output_dims[dnums.output_batch_dimension()] = 'b';
  output_dims[dnums.output_feature_dimension()] = 'f';
  for (int64 i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
    output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
  }

  return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
                StrJoin(output_dims, ""));
}

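// Renders replica groups as nested, comma-separated sets of replica ids,
// e.g. {{0,1},{2,3}}.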
string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups) {
  std::vector<string> replica_group_str;
  replica_group_str.reserve(replica_groups.size());
  for (const ReplicaGroup& group : replica_groups) {
    replica_group_str.push_back(
        StrCat("{", StrJoin(group.replica_ids(), ","), "}"));
  }
  return StrCat("{", StrJoin(replica_group_str, ","), "}");
}

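// The StringTo* parsers below invert the corresponding *ToString formatters by
// lazily building a lowercase-name -> enum value map on first use.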
StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name) {
  static std::unordered_map<string, RandomAlgorithm>* map = [] {
    static auto* map = new std::unordered_map<string, RandomAlgorithm>;
    for (int i = 0; i < RandomAlgorithm_ARRAYSIZE; i++) {
      if (RandomAlgorithm_IsValid(i)) {
        auto value = static_cast<RandomAlgorithm>(i);
        (*map)[RandomAlgorithmToString(value)] = value;
      }
    }
    return map;
  }();
  auto found = map->find(absl::AsciiStrToLower(name));
  if (found == map->end()) {
    return InvalidArgument("Unknown algorithm");
  }
  return found->second;
}

StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
  static std::unordered_map<string, RandomDistribution>* map = [] {
    static auto* map = new std::unordered_map<string, RandomDistribution>;
    for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
      if (RandomDistribution_IsValid(i)) {
        auto value = static_cast<RandomDistribution>(i);
        (*map)[RandomDistributionToString(value)] = value;
      }
    }
    return map;
  }();
  auto found = map->find(absl::AsciiStrToLower(name));
  if (found == map->end()) {
    return InvalidArgument("Unknown distribution");
  }
  return found->second;
}

StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
  static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
    static auto* map =
        new std::unordered_map<string, PrecisionConfig::Precision>;
    for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
      if (PrecisionConfig::Precision_IsValid(i)) {
        auto value = static_cast<PrecisionConfig::Precision>(i);
        (*map)[PrecisionToString(value)] = value;
      }
    }
    return map;
  }();
  auto found = map->find(absl::AsciiStrToLower(name));
  if (found == map->end()) {
    return InvalidArgument("Unknown precision");
  }
  return found->second;
}

std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
  return os << ToString(kind);
}

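// Orders instructions by (module unique_id, instruction unique_id) rather than
// by pointer value, with nullptr ordered before all non-null instructions, so
// that containers keyed on HloInstruction* iterate in a deterministic order.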
bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
                                  const HloInstruction* const& rhs) const {
  if (rhs == nullptr) {
    // Nothing compares less than nullptr.
    return false;
  }
  if (lhs == nullptr) {
    return true;
  }
  auto lhs_module = lhs->GetModule();
  auto rhs_module = rhs->GetModule();
  CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
        (lhs_module != nullptr && rhs_module != nullptr));
  if (lhs_module != nullptr &&
      lhs_module->unique_id() != rhs_module->unique_id()) {
    return lhs_module->unique_id() < rhs_module->unique_id();
  }
  return lhs->unique_id() < rhs->unique_id();
}

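// The backend config is stored as a human-readable JSON encoding of a
// backend-specific proto; the empty string denotes a default-initialized
// proto. Callers typically go through the templated backend_config<T>() and
// set_backend_config() wrappers declared in the header rather than calling
// this directly.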
Status HloInstruction::GetBackendConfigInternal(
    tensorflow::protobuf::Message* proto) const {
  proto->Clear();

  // Empty string does not parse as valid JSON, but it's a valid backend
  // config, corresponding to the empty proto.
  if (backend_config_.empty()) {
    return Status::OK();
  }
  return tensorflow::HumanReadableJsonToProto(backend_config_, proto);
}

Status HloInstruction::set_backend_config(
    const tensorflow::protobuf::Message& proto) {
  TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto));
  return Status::OK();
}

/* static */ StatusOr<string> HloInstruction::BackendConfigToRawString(
    const tensorflow::protobuf::Message& proto) {
  string ret;
  // Pass ignore_accuracy_loss = true because estimated_cycles field can be
  // INT64_MAX. If ignore_accuracy_loss = false and estimated_cycles =
  // INT64_MAX, JsonFormat will return an error status, although there is no
  // accuracy loss for int64.
  TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(
      proto, &ret, /*ignore_accuracy_loss=*/true));
  return ret;
}

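// A PrecisionConfig is only carried by convolution, dot and custom-call
// instructions; requesting it on any other opcode is a fatal error.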
const PrecisionConfig& HloInstruction::precision_config() const {
  if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->precision_config();
  }
  if (auto* dot = DynCast<HloDotInstruction>(this)) {
    return dot->precision_config();
  }
  if (auto* custom_call = DynCast<HloCustomCallInstruction>(this)) {
    return custom_call->precision_config();
  }
  LOG(FATAL) << "Unimplemented method.";
}

PrecisionConfig* HloInstruction::mutable_precision_config() {
  if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->mutable_precision_config();
  }
  if (auto* dot = DynCast<HloDotInstruction>(this)) {
    return dot->mutable_precision_config();
  }
  LOG(FATAL) << "Unimplemented method.";
}

HloModule* HloInstruction::GetModule() const {
  if (parent_) {
    return parent_->parent();
  }
  return nullptr;
}

void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
  string parent_str = parent() == nullptr ? "noparent" : parent()->name();
  name_ = name_uniquer->GetUniqueName(name_);
}

void HloInstruction::set_outer_dimension_partitions(
    const std::vector<int64>& outer_dimension_partitions) {
  outer_dimension_partitions_ = outer_dimension_partitions;
}

// TODO(b/80131774): Remove these temporary methods after transition.
int64 HloInstruction::feature_index() const {
  return Cast<HloBatchNormInstruction>(this)->feature_index();
}

float HloInstruction::epsilon() const {
  return Cast<HloBatchNormInstruction>(this)->epsilon();
}

FftType HloInstruction::fft_type() const {
  return Cast<HloFftInstruction>(this)->fft_type();
}

const std::vector<int64>& HloInstruction::fft_length() const {
  return Cast<HloFftInstruction>(this)->fft_length();
}

int64 HloInstruction::concatenate_dimension() const {
  return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
}

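// dimension() is shared by kGetDimensionSize and kSetDimensionSize
// instructions.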
int64 HloInstruction::dimension() const {
  if (auto set_size = DynCast<HloSetDimensionSizeInstruction>(this)) {
    return set_size->dimension();
  }
  return Cast<HloGetDimensionSizeInstruction>(this)->dimension();
}

int64 HloInstruction::inferred_dimension() const {
  return Cast<HloReshapeInstruction>(this)->inferred_dimension();
}

bool HloInstruction::IsRank2Transpose() const {
  auto transpose = DynCast<HloTransposeInstruction>(this);
  return transpose != nullptr && transpose->IsRank2Transpose();
}

int64 HloInstruction::slice_starts(int64 dimension) const {
  return Cast<HloSliceInstruction>(this)->slice_starts(dimension);
}

const std::vector<int64>& HloInstruction::slice_starts() const {
  return Cast<HloSliceInstruction>(this)->slice_starts();
}

std::vector<int64>* HloInstruction::mutable_slice_starts() {
  return Cast<HloSliceInstruction>(this)->mutable_slice_starts();
}

int64 HloInstruction::slice_limits(int64 dimension) const {
  return Cast<HloSliceInstruction>(this)->slice_limits(dimension);
}

const std::vector<int64>& HloInstruction::slice_limits() const {
  return Cast<HloSliceInstruction>(this)->slice_limits();
}

std::vector<int64>* HloInstruction::mutable_slice_limits() {
  return Cast<HloSliceInstruction>(this)->mutable_slice_limits();
}

int64 HloInstruction::slice_strides(int64 dimension) const {
  return Cast<HloSliceInstruction>(this)->slice_strides(dimension);
}

const std::vector<int64>& HloInstruction::slice_strides() const {
  return Cast<HloSliceInstruction>(this)->slice_strides();
}

std::vector<int64>* HloInstruction::mutable_slice_strides() {
  return Cast<HloSliceInstruction>(this)->mutable_slice_strides();
}

const Literal& HloInstruction::literal() const {
  return Cast<HloConstantInstruction>(this)->literal();
}

bool HloInstruction::IsConstant() const {
  return DynCast<HloConstantInstruction>(this) != nullptr;
}

void HloInstruction::RelayoutConstant(const Layout& new_layout,
                                      const ShapeIndex& shape_index) {
  Cast<HloConstantInstruction>(this)->RelayoutConstant(new_layout, shape_index);
}

string HloInstruction::TracingTag() const {
  return Cast<HloTraceInstruction>(this)->TracingTag();
}

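// The accessors below simply forward to the matching HloInstruction subclass
// (HloFusionInstruction, HloParameterInstruction, and so on); calling one on
// an instruction of the wrong kind trips the Cast<> check.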
HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
  return Cast<HloFusionInstruction>(this)->AddFusionOperand(new_operand);
}

// Delegates to HloFusionInstruction::MergeFusionInstruction.
void HloInstruction::MergeFusionInstruction(
    HloInstruction* instruction_to_merge) {
  return Cast<HloFusionInstruction>(this)->MergeFusionInstruction(
      Cast<HloFusionInstruction>(instruction_to_merge));
}

// Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
void HloInstruction::MergeFusionInstructionIntoMultiOutput(
    HloInstruction* instruction_to_merge) {
  return Cast<HloFusionInstruction>(this)
      ->MergeFusionInstructionIntoMultiOutput(
          Cast<HloFusionInstruction>(instruction_to_merge));
}

HloInstruction* HloInstruction::FuseInstruction(
    HloInstruction* instruction_to_fuse) {
  return Cast<HloFusionInstruction>(this)->FuseInstruction(instruction_to_fuse);
}

HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput(
    HloInstruction* instruction_to_fuse) {
  return Cast<HloFusionInstruction>(this)->FuseInstructionIntoMultiOutput(
      instruction_to_fuse);
}

HloComputation* HloInstruction::fused_instructions_computation() const {
  return Cast<HloFusionInstruction>(this)->fused_instructions_computation();
}

HloInstruction* HloInstruction::fused_expression_root() const {
  return Cast<HloFusionInstruction>(this)->fused_expression_root();
}

const tensorflow::gtl::iterator_range<UnwrappingIterator<
    std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
HloInstruction::fused_instructions() const {
  return Cast<HloFusionInstruction>(this)->fused_instructions();
}

const tensorflow::gtl::iterator_range<
    UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
HloInstruction::fused_instructions() {
  return Cast<HloFusionInstruction>(this)->fused_instructions();
}

int64 HloInstruction::fused_instruction_count() const {
  return Cast<HloFusionInstruction>(this)->fused_instruction_count();
}

HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
  return Cast<HloFusionInstruction>(this)->fused_parameter(parameter_number);
}

const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
  return Cast<HloFusionInstruction>(this)->fused_parameters();
}

bool HloInstruction::IsMultiOutputFusion() const {
  const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(this);
  return fusion != nullptr && fusion->IsMultiOutputFusion();
}

HloInstruction::FusionKind HloInstruction::fusion_kind() const {
  return Cast<HloFusionInstruction>(this)->fusion_kind();
}

void HloInstruction::set_fusion_kind(FusionKind kind) {
  return Cast<HloFusionInstruction>(this)->set_fusion_kind(kind);
}

RandomDistribution HloInstruction::random_distribution() const {
  return Cast<HloRngInstruction>(this)->random_distribution();
}

int64 HloInstruction::parameter_number() const {
  return Cast<HloParameterInstruction>(this)->parameter_number();
}

void HloInstruction::set_parameter_replicated_at_leaf_buffers(
    absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
  return Cast<HloParameterInstruction>(this)
      ->set_parameter_replicated_at_leaf_buffers(
          parameter_replicated_at_leaf_buffers);
}

void HloInstruction::set_parameter_replicated_at_leaf_buffers(
    const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
  return Cast<HloParameterInstruction>(this)
      ->set_parameter_replicated_at_leaf_buffers(
          parameter_replicated_at_leaf_buffers);
}

const absl::optional<std::vector<bool>>&
HloInstruction::parameter_replicated_at_leaf_buffers() const {
  return Cast<HloParameterInstruction>(this)
      ->parameter_replicated_at_leaf_buffers();
}

int64 HloInstruction::tuple_index() const {
  return Cast<HloGetTupleElementInstruction>(this)->tuple_index();
}

void HloInstruction::set_tuple_index(int64 new_tuple_index) {
  return Cast<HloGetTupleElementInstruction>(this)->set_tuple_index(
      new_tuple_index);
}

int32 HloInstruction::exponent_bits() const {
  return Cast<HloReducePrecisionInstruction>(this)->exponent_bits();
}

int32 HloInstruction::mantissa_bits() const {
  return Cast<HloReducePrecisionInstruction>(this)->mantissa_bits();
}

string HloInstruction::infeed_config() const {
  return Cast<HloInfeedInstruction>(this)->infeed_config();
}

void HloInstruction::set_infeed_config(const string& config) {
  return Cast<HloInfeedInstruction>(this)->set_infeed_config(config);
}

const Shape& HloInstruction::outfeed_shape() const {
  return Cast<HloOutfeedInstruction>(this)->outfeed_shape();
}

Shape* HloInstruction::mutable_outfeed_shape() {
  return Cast<HloOutfeedInstruction>(this)->mutable_outfeed_shape();
}

const string& HloInstruction::outfeed_config() const {
  return Cast<HloOutfeedInstruction>(this)->outfeed_config();
}

void HloInstruction::set_outfeed_config(const string& config) {
  return Cast<HloOutfeedInstruction>(this)->set_outfeed_config(config);
}

const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
  return Cast<HloCollectiveInstruction>(this)->replica_groups();
}

const std::vector<std::pair<int64, int64>>&
HloInstruction::source_target_pairs() const {
  return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
}

absl::optional<int64> HloInstruction::channel_id() const {
  return Cast<HloChannelInstruction>(this)->channel_id();
}

void HloInstruction::set_channel_id(const absl::optional<int64>& channel_id) {
  return Cast<HloChannelInstruction>(this)->set_channel_id(channel_id);
}

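// Convolution dimension numbers can be attached either to a kConvolution
// instruction or to a custom call that implements a convolution (for example
// a backend-specific cuDNN convolution call), so both cases are handled here.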
const ConvolutionDimensionNumbers&
HloInstruction::convolution_dimension_numbers() const {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->convolution_dimension_numbers();
  }
  if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
    return custom_call->convolution_dimension_numbers();
  }
  LOG(FATAL) << "Unimplemented method.";
}

void HloInstruction::set_convolution_dimension_numbers(
    const ConvolutionDimensionNumbers& dnums) {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    convolution->set_convolution_dimension_numbers(dnums);
  } else if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
    custom_call->set_convolution_dimension_numbers(dnums);
  } else {
    LOG(FATAL) << "Unimplemented method.";
  }
}

int64 HloInstruction::feature_group_count() const {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->feature_group_count();
  }
  return Cast<HloCustomCallInstruction>(this)->feature_group_count();
}

void HloInstruction::set_feature_group_count(int64 feature_group_count) {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->set_feature_group_count(feature_group_count);
  }
  Cast<HloCustomCallInstruction>(this)->set_feature_group_count(
      feature_group_count);
}

int64 HloInstruction::batch_group_count() const {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->batch_group_count();
  }
  return Cast<HloCustomCallInstruction>(this)->batch_group_count();
}

void HloInstruction::set_batch_group_count(int64 batch_group_count) {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->set_batch_group_count(batch_group_count);
  }
  Cast<HloCustomCallInstruction>(this)->set_batch_group_count(
      batch_group_count);
}

HloComputation* HloInstruction::select() const {
  return Cast<HloSelectAndScatterInstruction>(this)->select();
}

HloComputation* HloInstruction::scatter() const {
  return Cast<HloSelectAndScatterInstruction>(this)->scatter();
}

void HloInstruction::set_select(HloComputation* computation) {
  return Cast<HloSelectAndScatterInstruction>(this)->set_select(computation);
}

void HloInstruction::set_scatter(HloComputation* computation) {
  return Cast<HloSelectAndScatterInstruction>(this)->set_scatter(computation);
}

const string& HloInstruction::custom_call_target() const {
  return Cast<HloCustomCallInstruction>(this)->custom_call_target();
}

const PaddingConfig& HloInstruction::padding_config() const {
  return Cast<HloPadInstruction>(this)->padding_config();
}

PaddingType HloInstruction::padding_type() const {
  return Cast<HloCustomCallInstruction>(this)->padding_type();
}

PaddingConfig* HloInstruction::mutable_padding_config() {
  return Cast<HloPadInstruction>(this)->mutable_padding_config();
}

int64 HloInstruction::slice_sizes(int64 dimension) const {
  return Cast<HloDynamicSliceInstruction>(this)->slice_sizes(dimension);
}

const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const {
  return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes();
}

const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
  return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
}

absl::Span<const int64> HloInstruction::gather_slice_sizes() const {
  return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
}

const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
    const {
  return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
}

const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
  return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
}

const DomainMetadata& HloInstruction::operand_side_metadata() const {
  return Cast<HloDomainInstruction>(this)->operand_side_metadata();
}

const DomainMetadata& HloInstruction::user_side_metadata() const {
  return Cast<HloDomainInstruction>(this)->user_side_metadata();
}

bool HloInstruction::is_cross_program_prefetch() const {
  return Cast<HloCopyStartInstruction>(this)->is_cross_program_prefetch();
}

ComparisonDirection HloInstruction::comparison_direction() const {
  return Cast<HloCompareInstruction>(this)->direction();
}

const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
  return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
}

const CholeskyOptions& HloInstruction::cholesky_options() const {
  return Cast<HloCholeskyInstruction>(this)->cholesky_options();
}

}  // namespace xla