1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
17 
18 #include <algorithm>
19 #include <ostream>
20 #include <set>
21 #include <unordered_set>
22 #include <utility>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/ascii.h"
30 #include "absl/strings/escaping.h"
31 #include "absl/strings/numbers.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/types/span.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/literal.h"
37 #include "tensorflow/compiler/xla/protobuf_util.h"
38 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
39 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
40 #include "tensorflow/compiler/xla/service/hlo_computation.h"
41 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
42 #include "tensorflow/compiler/xla/service/hlo_module.h"
43 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
44 #include "tensorflow/compiler/xla/service/name_uniquer.h"
45 #include "tensorflow/compiler/xla/shape_util.h"
46 #include "tensorflow/compiler/xla/status_macros.h"
47 #include "tensorflow/compiler/xla/types.h"
48 #include "tensorflow/compiler/xla/util.h"
49 #include "tensorflow/core/lib/core/errors.h"
50 #include "tensorflow/core/lib/gtl/map_util.h"
51 #include "tensorflow/core/platform/human_readable_json.h"
52 #include "tensorflow/core/platform/logging.h"
53 
54 namespace xla {
55 
56 using absl::CEscape;
57 using absl::StrAppend;
58 using absl::StrCat;
59 using absl::StrJoin;
60 
61 /* static */
CreateFromProto(const HloInstructionProto & proto,const absl::flat_hash_map<int64,HloInstruction * > & instruction_map,const absl::flat_hash_map<int64,HloComputation * > & computation_map)62 StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
63     const HloInstructionProto& proto,
64     const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
65     const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
66   TF_RET_CHECK(!proto.opcode().empty());
67   HloOpcode opcode;
68   auto opcode_or = StringToHloOpcode(proto.opcode());
69   absl::optional<ComparisonDirection> comparison_direction;
70   if (opcode_or.ok()) {
71     opcode = opcode_or.ConsumeValueOrDie();
72   } else {
73     // Unknown opcode. Try auto-upgrading deprecated "less-than",
74     // "greater-than", etc opcodes, which are now rolled into the kCompare
75     // opcode.
76     if (proto.opcode() == "equal-to") {
77       comparison_direction = ComparisonDirection::kEq;
78     } else if (proto.opcode() == "not-equal-to") {
79       comparison_direction = ComparisonDirection::kNe;
80     } else if (proto.opcode() == "greater-than-or-equal-to") {
81       comparison_direction = ComparisonDirection::kGe;
82     } else if (proto.opcode() == "greater-than") {
83       comparison_direction = ComparisonDirection::kGt;
84     } else if (proto.opcode() == "less-than-or-equal-to") {
85       comparison_direction = ComparisonDirection::kLe;
86     } else if (proto.opcode() == "less-than") {
87       comparison_direction = ComparisonDirection::kLt;
88     }
89     if (comparison_direction) {
90       opcode = HloOpcode::kCompare;
91     } else {
92       return InvalidArgument("Unknown opcode: %s", proto.opcode());
93     }
94   }
95 
96   TF_RET_CHECK(proto.has_shape());
97 
98   std::unique_ptr<HloInstruction> instruction;
99   const auto operands = [&instruction_map, &proto](int index) {
100     return instruction_map.at(proto.operand_ids(index));
101   };
102   const auto all_operands = [&instruction_map, &proto]() {
103     std::vector<HloInstruction*> result(proto.operand_ids_size());
104     std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
105                    result.begin(), [&instruction_map](int64 operand_id) {
106                      return instruction_map.at(operand_id);
107                    });
108     return result;
109   };
110   const auto computations = [&computation_map, &proto](int index) {
111     return computation_map.at(proto.called_computation_ids(index));
112   };
113   const auto all_computations = [&computation_map, &proto]() {
114     std::vector<HloComputation*> result(proto.called_computation_ids_size());
115     std::transform(proto.called_computation_ids().begin(),
116                    proto.called_computation_ids().end(), result.begin(),
117                    [&computation_map](int64 computation_id) {
118                      return computation_map.at(computation_id);
119                    });
120     return result;
121   };
122 
123   TF_RET_CHECK(
124       absl::c_all_of(proto.operand_ids(),
125                      [&](int64 id) { return instruction_map.contains(id); }))
126       << proto.name() << " instruction contains invalid operand id(s)";
127 
128   TF_RET_CHECK(
129       absl::c_all_of(proto.called_computation_ids(),
130                      [&](int64 id) { return computation_map.contains(id); }))
131       << proto.name() << " instruction references invalid computation id(s)";
132 
133   Shape shape(proto.shape());
134   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
135 
136   absl::optional<int> arity = HloOpcodeArity(opcode);
137   if (arity) {
138     TF_RET_CHECK(proto.operand_ids_size() == *arity)
139         << proto.opcode() << " instruction should have " << *arity
140         << " operands but sees " << proto.operand_ids_size();
141   }
142 
143   switch (opcode) {
144     // Ops migrated to subclasses.
145     case HloOpcode::kBatchNormTraining:
146       instruction =
147           CreateBatchNormTraining(shape, operands(0), operands(1), operands(2),
148                                   proto.epsilon(), proto.feature_index());
149       break;
150     case HloOpcode::kBatchNormInference:
151       instruction = CreateBatchNormInference(
152           shape, operands(0), operands(1), operands(2), operands(3),
153           operands(4), proto.epsilon(), proto.feature_index());
154       break;
155     case HloOpcode::kBatchNormGrad:
156       instruction = CreateBatchNormGrad(shape, operands(0), operands(1),
157                                         operands(2), operands(3), operands(4),
158                                         proto.epsilon(), proto.feature_index());
159       break;
160     case HloOpcode::kFft: {
161       std::vector<int64> fft_length(proto.fft_length().begin(),
162                                     proto.fft_length().end());
163       instruction = CreateFft(shape, operands(0), proto.fft_type(),
164                               absl::Span<const int64>(fft_length));
165       break;
166     }
167     case HloOpcode::kCompare: {
168       // Auto-upgraded from deprecated opcode skips the following.
169       if (!comparison_direction) {
170         TF_ASSIGN_OR_RETURN(
171             comparison_direction,
172             StringToComparisonDirection(proto.comparison_direction()));
173       }
174       instruction =
175           CreateCompare(shape, operands(0), operands(1), *comparison_direction);
176       break;
177     }
178     case HloOpcode::kTriangularSolve: {
179       instruction = CreateTriangularSolve(shape, operands(0), operands(1),
180                                           proto.triangular_solve_options());
181       break;
182     }
183     case HloOpcode::kCholesky: {
184       instruction =
185           CreateCholesky(shape, operands(0), proto.cholesky_options());
186       break;
187     }
188     case HloOpcode::kSend:
189       instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
190                                proto.is_host_transfer());
191       break;
192     case HloOpcode::kSendDone:
193       instruction = CreateSendDone(operands(0), proto.is_host_transfer());
194       break;
195     case HloOpcode::kRecv:
196       instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
197                                proto.channel_id(), proto.is_host_transfer());
198       break;
199     case HloOpcode::kRecvDone:
200       instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
201       break;
202     case HloOpcode::kReverse:
203       instruction = CreateReverse(shape, operands(0),
204                                   std::vector<int64>(proto.dimensions().begin(),
205                                                      proto.dimensions().end()));
206       break;
207     case HloOpcode::kConcatenate:
208       TF_RET_CHECK(proto.dimensions_size() == 1)
209           << "Concatenate instruction should have 1 dimension but sees "
210           << proto.dimensions_size();
211       instruction =
212           CreateConcatenate(shape, all_operands(), proto.dimensions(0));
213       break;
214     case HloOpcode::kConditional: {
215       TF_RET_CHECK(proto.called_computation_ids_size() > 0)
216           << "conditional should have at least 1 called computation";
217       if (operands(0)->shape().element_type() == PRED) {
218         TF_RET_CHECK(proto.called_computation_ids_size() == 2)
219             << "conditional should have exactly 2 called computations but got "
220             << proto.called_computation_ids_size();
221       }
222       TF_RET_CHECK(proto.operand_ids_size() ==
223                    proto.called_computation_ids_size() + 1)
224           << "conditional should have one branch_index operand plus one "
225              "operand per called computation but got "
226           << proto.operand_ids_size() << " operands for "
227           << proto.called_computation_ids_size() << " branch computations";
228       auto cond_operands = all_operands();
229       instruction =
230           CreateConditional(shape, cond_operands[0], all_computations(),
231                             absl::MakeSpan(cond_operands).subspan(1));
232       break;
233     }
234     case HloOpcode::kReduce:
235       TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
236           << "Reduce instruction should have an even number of operands but "
237              "sees "
238           << proto.operand_ids_size();
239       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
240           << "Reduce instruction should have 1 called computation but sees "
241           << proto.called_computation_ids_size();
242       {
243         const auto reduce_operands = all_operands();
244         auto inputs = absl::MakeSpan(reduce_operands)
245                           .subspan(0, reduce_operands.size() / 2);
246         auto init_values =
247             absl::MakeSpan(reduce_operands)
248                 .subspan(reduce_operands.size() / 2, reduce_operands.size());
249         instruction =
250             CreateReduce(shape, inputs, init_values,
251                          std::vector<int64>(proto.dimensions().begin(),
252                                             proto.dimensions().end()),
253                          computations(0));
254       }
255       break;
256     case HloOpcode::kSort: {
257       TF_RET_CHECK(proto.operand_ids_size() >= 1)
258           << "Sort instruction should have at least 1 operand but has "
259           << proto.operand_ids_size();
260       TF_RET_CHECK(proto.dimensions().size() == 1)
261           << "Sort instruction should have 1 dimension";
262       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
263           << "Sort instruction should one called computation but sees "
264           << proto.called_computation_ids_size();
265       auto sort_operands = all_operands();
266       instruction = CreateSort(shape, proto.dimensions(0), all_operands(),
267                                computations(0), proto.is_stable());
268       break;
269     }
270     case HloOpcode::kTranspose:
271       instruction =
272           CreateTranspose(shape, operands(0),
273                           std::vector<int64>(proto.dimensions().begin(),
274                                              proto.dimensions().end()));
275       break;
276     case HloOpcode::kBroadcast:
277       instruction =
278           CreateBroadcast(shape, operands(0),
279                           std::vector<int64>(proto.dimensions().begin(),
280                                              proto.dimensions().end()));
281       break;
282     case HloOpcode::kMap:
283       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
284           << "Map instruction should have 1 called computation but sees "
285           << proto.called_computation_ids_size();
286       instruction = CreateMap(shape, all_operands(), computations(0));
287       break;
288     case HloOpcode::kSlice: {
289       std::vector<int64> slice_starts, slice_limits, slice_strides;
290       for (const HloInstructionProto::SliceDimensions& slice_dimensions :
291            proto.slice_dimensions()) {
292         slice_starts.push_back(slice_dimensions.start());
293         slice_limits.push_back(slice_dimensions.limit());
294         slice_strides.push_back(slice_dimensions.stride());
295       }
296       instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits,
297                                 slice_strides);
298       break;
299     }
300     case HloOpcode::kConstant: {
301       // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
302       if (proto.has_literal()) {
303         TF_ASSIGN_OR_RETURN(auto literal,
304                             Literal::CreateFromProto(proto.literal()));
305         instruction = CreateConstant(std::move(literal));
306       } else {
307         instruction = absl::make_unique<HloConstantInstruction>(shape);
308       }
309       break;
310     }
311     case HloOpcode::kTrace: {
312       TF_RET_CHECK(proto.has_literal());
313       TF_ASSIGN_OR_RETURN(auto literal,
314                           Literal::CreateFromProto(proto.literal()));
315       instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
316       break;
317     }
318     case HloOpcode::kFusion: {
319       // In the proto, fused computations are held exclusively within the
320       // HloInstructionProto and do not appear as an HloComputationProto within
321       // the HloModuleProto.
322       TF_RET_CHECK(!proto.fusion_kind().empty());
323       TF_ASSIGN_OR_RETURN(FusionKind fusion_kind,
324                           StringToFusionKind(proto.fusion_kind()));
325 
326       // Find the fused computation and set its fusion instruction.
327       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
328           << "Expect 1 called computation for fusion instruction but sees "
329           << proto.called_computation_ids_size();
330       const int64 fusion_id = proto.called_computation_ids(0);
331       auto* fused_computation =
332           tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
333       TF_RET_CHECK(fused_computation != nullptr)
334           << "No fusion computation with id " << fusion_id;
335       instruction =
336           CreateFusion(shape, fusion_kind, all_operands(), fused_computation);
337       break;
338     }
339     case HloOpcode::kRng:
340       instruction = CreateRng(shape, proto.distribution(), all_operands());
341       break;
342     case HloOpcode::kParameter:
343       instruction =
344           CreateParameter(proto.parameter_number(), shape, proto.name());
345       if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) {
346         instruction->set_parameter_replicated_at_leaf_buffers(
347             proto.parameter_replication().replicated_at_leaf_buffers());
348       }
349       break;
350     case HloOpcode::kGetTupleElement:
351       instruction =
352           CreateGetTupleElement(shape, operands(0), proto.tuple_index());
353       break;
354     case HloOpcode::kReducePrecision:
355       instruction = CreateReducePrecision(
356           shape, operands(0), proto.exponent_bits(), proto.mantissa_bits());
357       break;
358     case HloOpcode::kInfeed: {
359       TF_RET_CHECK(shape.IsTuple() &&
360                    (ShapeUtil::TupleElementCount(shape) == 2))
361           << "Infeed should have a tuple shape with 2 operands, but has: "
362           << shape;
363       const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0);
364       instruction =
365           CreateInfeed(data_shape, operands(0), proto.infeed_config());
366     } break;
367     case HloOpcode::kOutfeed: {
368       Shape outfeed_shape(proto.outfeed_shape());
369       TF_RETURN_IF_ERROR(
370           ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape));
371       instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1),
372                                   proto.outfeed_config());
373       break;
374     }
375     case HloOpcode::kAllReduce: {
376       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
377           << "AllReduce should have 1 called computation but sees "
378           << proto.called_computation_ids_size();
379       absl::optional<int64> all_reduce_id;
380       if (proto.all_reduce_id() > 0) {
381         all_reduce_id = proto.all_reduce_id();
382       }
383       instruction = CreateAllReduce(
384           shape, all_operands(), computations(0),
385           /*replica_groups=*/
386           std::vector<ReplicaGroup>(proto.replica_groups().begin(),
387                                     proto.replica_groups().end()),
388           /*barrier=*/proto.all_reduce_barrier(),
389           /*all_reduce_id=*/all_reduce_id);
390       break;
391     }
392     case HloOpcode::kAllToAll: {
393       instruction = CreateAllToAll(
394           shape, all_operands(),
395           /*replica_groups=*/
396           std::vector<ReplicaGroup>(proto.replica_groups().begin(),
397                                     proto.replica_groups().end()));
398       break;
399     }
400     case HloOpcode::kCollectivePermute: {
401       std::vector<std::pair<int64, int64>> source_target_pairs(
402           proto.source_target_pairs_size());
403       for (int i = 0; i < source_target_pairs.size(); i++) {
404         source_target_pairs[i].first = proto.source_target_pairs(i).source();
405         source_target_pairs[i].second = proto.source_target_pairs(i).target();
406       }
407       instruction =
408           CreateCollectivePermute(shape, operands(0), source_target_pairs);
409       break;
410     }
411     case HloOpcode::kReplicaId: {
412       instruction = CreateReplicaId();
413       break;
414     }
415     case HloOpcode::kConvolution: {
416       TF_RET_CHECK(proto.has_window());
417       TF_RET_CHECK(proto.has_convolution_dimension_numbers());
418       PrecisionConfig precision_config = proto.precision_config();
419       precision_config.mutable_operand_precision()->Resize(
420           proto.operand_ids_size(), PrecisionConfig::DEFAULT);
421       instruction = CreateConvolve(
422           shape, operands(0), operands(1),
423           std::max<int64>(proto.feature_group_count(), 1),
424           std::max<int64>(proto.batch_group_count(), 1), proto.window(),
425           proto.convolution_dimension_numbers(), precision_config);
426       break;
427     }
428     case HloOpcode::kReduceWindow:
429       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
430           << "ReduceWindow should have 1 called computation but sees "
431           << proto.called_computation_ids_size();
432       instruction = CreateReduceWindow(shape, operands(0), operands(1),
433                                        proto.window(), computations(0));
434       break;
435     case HloOpcode::kSelectAndScatter:
436       TF_RET_CHECK(proto.called_computation_ids_size() == 2)
437           << "SelectAndScatter should have 2 called computations but sees "
438           << proto.called_computation_ids_size();
439       instruction = CreateSelectAndScatter(shape, operands(0), computations(0),
440                                            proto.window(), operands(1),
441                                            operands(2), computations(1));
442       break;
443     case HloOpcode::kCustomCall:
444       if (proto.constrain_layout()) {
445         // A proto RepeatedPtrField cannot be converted to a Span (it is a
446         // vector of pointers essentially) so create a vector of shapes to pass
447         // in.
448         std::vector<Shape> operand_shapes;
449         for (const ShapeProto& shape_proto :
450              proto.operand_shapes_with_layout()) {
451           operand_shapes.emplace_back(shape_proto);
452         }
453         instruction =
454             CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
455                              operand_shapes, proto.custom_call_opaque());
456       } else {
457         instruction =
458             CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
459                              proto.custom_call_opaque());
460       }
461       if (proto.has_window()) {
462         static_cast<HloCustomCallInstruction*>(instruction.get())
463             ->set_window(proto.window());
464       }
465       if (proto.has_convolution_dimension_numbers()) {
466         static_cast<HloCustomCallInstruction*>(instruction.get())
467             ->set_convolution_dimension_numbers(
468                 proto.convolution_dimension_numbers());
469       }
470       static_cast<HloCustomCallInstruction*>(instruction.get())
471           ->set_feature_group_count(
472               std::max(static_cast<int64>(proto.feature_group_count()), 1LL));
473       static_cast<HloCustomCallInstruction*>(instruction.get())
474           ->set_batch_group_count(
475               std::max(static_cast<int64>(proto.batch_group_count()), 1LL));
476       break;
477     case HloOpcode::kPad:
478       TF_RET_CHECK(proto.has_padding_config());
479       instruction =
480           CreatePad(shape, operands(0), operands(1), proto.padding_config());
481       break;
482     case HloOpcode::kDynamicSlice: {
483       std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
484       absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
485       TF_RET_CHECK(proto.operand_ids_size() >= 1)
486           << "DynamicSlice instruction should have at least 1 operands but "
487              "sees "
488           << proto.operand_ids_size();
489       // TODO(b/118437727): Old form, make the check unconditional.
490       if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) {
491         auto expected_operands = 1 + operands(0)->shape().rank();
492         TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
493             << "DynamicSlice instruction should have " << expected_operands
494             << " operands, but has " << proto.operand_ids_size();
495       }
496       const auto& operand_vector = all_operands();
497       instruction = CreateDynamicSlice(
498           shape, operands(0), absl::MakeSpan(operand_vector).subspan(1),
499           slice_sizes);
500       break;
501     }
502     case HloOpcode::kDynamicUpdateSlice: {
503       TF_RET_CHECK(proto.operand_ids_size() >= 2)
504           << "DynamicUpdateSlice instruction should have at least 2 operands "
505              "but sees "
506           << proto.operand_ids_size();
507       // TODO(b/118437727): Old form, make the check unconditional.
508       if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) {
509         auto expected_operands = 2 + operands(0)->shape().rank();
510         TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
511             << "DynamicUpdateSlice instruction should have "
512             << expected_operands << " operands, but has "
513             << proto.operand_ids_size();
514       }
515       const auto& operand_vector = all_operands();
516       instruction =
517           CreateDynamicUpdateSlice(shape, operands(0), operands(1),
518                                    absl::MakeSpan(operand_vector).subspan(2));
519 
520       break;
521     }
522     case HloOpcode::kGather: {
523       TF_RET_CHECK(proto.has_gather_dimension_numbers())
524           << "Gather instruction should have GatherDimensionNumbers set.";
525       std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers =
526           absl::make_unique<GatherDimensionNumbers>(
527               proto.gather_dimension_numbers());
528       std::vector<int64> gather_slice_sizes;
529       for (int64 bound : proto.gather_slice_sizes()) {
530         gather_slice_sizes.push_back(bound);
531       }
532       instruction = CreateGather(shape, operands(0), operands(1),
533                                  *gather_dimension_numbers, gather_slice_sizes);
534       break;
535     }
536     case HloOpcode::kScatter: {
537       TF_RET_CHECK(proto.has_scatter_dimension_numbers())
538           << "Scatter instruction should have ScatterDimensionNumbers set.";
539       TF_RET_CHECK(proto.called_computation_ids_size() == 1)
540           << "Scatter instruction should have 1 called computation but sees "
541           << proto.called_computation_ids_size();
542       auto scatter_dimension_numbers =
543           absl::make_unique<ScatterDimensionNumbers>(
544               proto.scatter_dimension_numbers());
545       instruction = CreateScatter(shape, operands(0), operands(1), operands(2),
546                                   computations(0), *scatter_dimension_numbers);
547       break;
548     }
549     case HloOpcode::kIota:
550       TF_RET_CHECK(proto.dimensions_size() == 1)
551           << "Iota instruction should have 1 dimension but sees "
552           << proto.dimensions_size();
553       instruction = CreateIota(shape, proto.dimensions(0));
554       break;
555     case HloOpcode::kDot: {
556       TF_RET_CHECK(proto.has_dot_dimension_numbers())
557           << "Dot instruction should have dot_dimension_numbers.";
558       PrecisionConfig precision_config = proto.precision_config();
559       precision_config.mutable_operand_precision()->Resize(
560           proto.operand_ids_size(), PrecisionConfig::DEFAULT);
561       instruction = absl::make_unique<HloDotInstruction>(
562           shape, operands(0), operands(1), proto.dot_dimension_numbers(),
563           precision_config);
564       break;
565     }
566     case HloOpcode::kDomain: {
567       std::shared_ptr<const HloSharding> entry_hlo_sharding;
568       std::shared_ptr<const HloSharding> exit_hlo_sharding;
569       if (proto.has_domain_entry_sharding()) {
570         TF_ASSIGN_OR_RETURN(
571             HloSharding sharding,
572             HloSharding::FromProto(proto.domain_entry_sharding()));
573         entry_hlo_sharding = std::make_shared<const HloSharding>(sharding);
574       }
575       if (proto.has_domain_exit_sharding()) {
576         TF_ASSIGN_OR_RETURN(
577             HloSharding sharding,
578             HloSharding::FromProto(proto.domain_exit_sharding()));
579         exit_hlo_sharding = std::make_shared<const HloSharding>(sharding);
580       }
581       instruction = absl::make_unique<HloDomainInstruction>(
582           shape, operands(0),
583           absl::make_unique<ShardingMetadata>(entry_hlo_sharding),
584           absl::make_unique<ShardingMetadata>(exit_hlo_sharding));
585       break;
586     }
587     case HloOpcode::kGetDimensionSize:
588       TF_RET_CHECK(proto.dimensions_size() == 1);
589       instruction =
590           CreateGetDimensionSize(shape, operands(0), proto.dimensions(0));
591       break;
592     default: {
593       instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
594       for (const int64 operand_id : proto.operand_ids()) {
595         instruction->AppendOperand(instruction_map.at(operand_id));
596       }
597       if (instruction->opcode() != HloOpcode::kFusion) {
598         for (const int64 computation_id : proto.called_computation_ids()) {
599           instruction->called_computations_.push_back(
600               computation_map.at(computation_id));
601         }
602       }
603       TF_RET_CHECK(!proto.has_precision_config())
604           << instruction->opcode() << proto.DebugString();
605       TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
606       break;
607     }
608   }
609 
610   for (const int64 predecessor_id : proto.control_predecessor_ids()) {
611     TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
612         << "No instruction with id " << predecessor_id;
613     TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
614                            ->AddControlDependencyTo(instruction.get()));
615   }
616 
617   TF_RET_CHECK(!proto.name().empty());
618   instruction->SetAndSanitizeName(proto.name());
619   instruction->metadata_ = proto.metadata();
620   instruction->backend_config_ = proto.backend_config();
621 
622   TF_RET_CHECK(proto.id() >= 0)
623       << "Instruction with negative id: " << proto.id();
624   TF_RET_CHECK(proto.id() <= INT_MAX)
625       << "Instruction with id > INT_MAX: " << proto.id();
626   instruction->unique_id_ = proto.id();
627 
628   if (proto.has_sharding()) {
629     TF_ASSIGN_OR_RETURN(const auto& sharding,
630                         HloSharding::FromProto(proto.sharding()));
631     instruction->set_sharding(sharding);
632   }
633 
634   return std::move(instruction);
635 }
636 
CreateParameter(int64 parameter_number,const Shape & shape,const string & name)637 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
638     int64 parameter_number, const Shape& shape, const string& name) {
639   return absl::make_unique<HloParameterInstruction>(parameter_number, shape,
640                                                     name);
641 }
642 
CreateTrace(const string & tag,HloInstruction * operand)643 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
644     const string& tag, HloInstruction* operand) {
645   return absl::make_unique<HloTraceInstruction>(tag, operand);
646 }
647 
CreateConstant(Literal literal)648 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
649     Literal literal) {
650   return absl::make_unique<HloConstantInstruction>(std::move(literal));
651 }
652 
CreateIota(const Shape & shape,int64 iota_dimension)653 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
654     const Shape& shape, int64 iota_dimension) {
655   return absl::make_unique<HloIotaInstruction>(shape, iota_dimension);
656 }
657 
658 /* static */ std::unique_ptr<HloInstruction>
CreateGetTupleElement(const Shape & shape,HloInstruction * operand,int64 index)659 HloInstruction::CreateGetTupleElement(const Shape& shape,
660                                       HloInstruction* operand, int64 index) {
661   return absl::make_unique<HloGetTupleElementInstruction>(shape, operand,
662                                                           index);
663 }
664 
CreateRng(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)665 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
666     const Shape& shape, RandomDistribution distribution,
667     absl::Span<HloInstruction* const> parameters) {
668   return absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
669 }
670 
CreateNary(const Shape & shape,HloOpcode opcode,absl::Span<HloInstruction * const> operands)671 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
672     const Shape& shape, HloOpcode opcode,
673     absl::Span<HloInstruction* const> operands) {
674   if (opcode == HloOpcode::kCopy) {
675     // It is impossible to copy an opaque shape, we don't know how big it is.
676     CHECK(!shape.IsOpaque());
677   }
678   auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
679   for (auto operand : operands) {
680     instruction->AppendOperand(operand);
681   }
682   return instruction;
683 }
684 
CreateUnary(const Shape & shape,HloOpcode opcode,HloInstruction * operand)685 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
686     const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
687   // Only certain opcodes are supported with CreateUnary: opcodes of unary
688   // instructions with no auxiliary fields.
689   switch (opcode) {
690     case HloOpcode::kAbs:
691     case HloOpcode::kRoundNearestAfz:
692     case HloOpcode::kBitcast:
693     case HloOpcode::kCeil:
694     case HloOpcode::kCopy:
695     case HloOpcode::kCos:
696     case HloOpcode::kClz:
697     case HloOpcode::kExp:
698     case HloOpcode::kExpm1:
699     case HloOpcode::kFloor:
700     case HloOpcode::kImag:
701     case HloOpcode::kIsFinite:
702     case HloOpcode::kLog:
703     case HloOpcode::kLog1p:
704     case HloOpcode::kNot:
705     case HloOpcode::kNegate:
706     case HloOpcode::kReal:
707     case HloOpcode::kRsqrt:
708     case HloOpcode::kSign:
709     case HloOpcode::kSin:
710     case HloOpcode::kSqrt:
711     case HloOpcode::kTanh:
712       break;
713     default:
714       LOG(FATAL) << "Invalid unary instruction opcode "
715                  << HloOpcodeString(opcode);
716   }
717   return CreateNary(shape, opcode, {operand});
718 }
719 
CreateBinary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs)720 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
721     const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
722     HloInstruction* rhs) {
723   // Only certain opcodes are supported with CreateBinary: opcodes of binary
724   // instructions with no auxiliary fields.
725   switch (opcode) {
726     case HloOpcode::kAdd:
727     case HloOpcode::kAtan2:
728     case HloOpcode::kDivide:
729     case HloOpcode::kComplex:
730     case HloOpcode::kMaximum:
731     case HloOpcode::kMinimum:
732     case HloOpcode::kMultiply:
733     case HloOpcode::kPower:
734     case HloOpcode::kRemainder:
735     case HloOpcode::kSubtract:
736     case HloOpcode::kAnd:
737     case HloOpcode::kOr:
738     case HloOpcode::kXor:
739     case HloOpcode::kShiftLeft:
740     case HloOpcode::kShiftRightArithmetic:
741     case HloOpcode::kShiftRightLogical:
742       break;
743     default:
744       LOG(FATAL) << "Invalid binary instruction opcode "
745                  << HloOpcodeString(opcode);
746   }
747   return CreateNary(shape, opcode, {lhs, rhs});
748 }
749 
CreateTernary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs,HloInstruction * ehs)750 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
751     const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
752     HloInstruction* rhs, HloInstruction* ehs) {
753   // Only certain opcodes are supported with CreateTernary: opcodes of ternary
754   // instructions with no auxiliary fields.
755   switch (opcode) {
756     case HloOpcode::kClamp:
757     case HloOpcode::kSelect:
758     case HloOpcode::kTupleSelect:
759       break;
760     default:
761       LOG(FATAL) << "Invalid ternary instruction opcode "
762                  << HloOpcodeString(opcode);
763   }
764   return CreateNary(shape, opcode, {lhs, rhs, ehs});
765 }
766 
CreateVariadic(const Shape & shape,HloOpcode opcode,absl::Span<HloInstruction * const> operands)767 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
768     const Shape& shape, HloOpcode opcode,
769     absl::Span<HloInstruction* const> operands) {
770   CHECK_EQ(HloOpcode::kTuple, opcode);
771   return CreateNary(shape, opcode, operands);
772 }
773 
CreateMap(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)774 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
775     const Shape& shape, absl::Span<HloInstruction* const> operands,
776     HloComputation* map_computation) {
777   return absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
778 }
779 
CreateConvolve(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64 feature_group_count,int64 batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)780 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
781     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
782     int64 feature_group_count, int64 batch_group_count, const Window& window,
783     const ConvolutionDimensionNumbers& dimension_numbers,
784     const PrecisionConfig& precision_config) {
785   return absl::make_unique<HloConvolutionInstruction>(
786       shape, lhs, rhs, feature_group_count, batch_group_count, window,
787       dimension_numbers, precision_config);
788 }
789 
CreateFft(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64> fft_length)790 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
791     const Shape& shape, HloInstruction* operand, FftType fft_type,
792     absl::Span<const int64> fft_length) {
793   return absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
794                                               fft_length);
795 }
796 
CreateCompare(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction)797 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
798     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
799     ComparisonDirection direction) {
800   return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction);
801 }
802 
803 /* static */ std::unique_ptr<HloInstruction>
CreateTriangularSolve(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)804 HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
805                                       HloInstruction* b,
806                                       const TriangularSolveOptions& options) {
807   return absl::make_unique<HloTriangularSolveInstruction>(shape, a, b, options);
808 }
809 
CreateCholesky(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)810 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCholesky(
811     const Shape& shape, HloInstruction* a, const CholeskyOptions& options) {
812   return absl::make_unique<HloCholeskyInstruction>(shape, a, options);
813 }
814 
CreateDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)815 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
816     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
817     const DotDimensionNumbers& dimension_numbers,
818     const PrecisionConfig& precision_config) {
819   return absl::make_unique<HloDotInstruction>(
820       shape, lhs, rhs, dimension_numbers, precision_config);
821 }
822 
823 /* static */ std::unique_ptr<HloInstruction>
CreateReducePrecision(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)824 HloInstruction::CreateReducePrecision(const Shape& shape,
825                                       HloInstruction* operand,
826                                       const int exponent_bits,
827                                       const int mantissa_bits) {
828   return absl::make_unique<HloReducePrecisionInstruction>(
829       shape, operand, exponent_bits, mantissa_bits);
830 }
831 
CreateAllReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,const std::vector<ReplicaGroup> & replica_groups,absl::string_view barrier,const absl::optional<int64> & all_reduce_id)832 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllReduce(
833     const Shape& shape, absl::Span<HloInstruction* const> operands,
834     HloComputation* reduce_computation,
835     const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
836     const absl::optional<int64>& all_reduce_id) {
837   return absl::make_unique<HloAllReduceInstruction>(
838       shape, operands, reduce_computation, replica_groups, barrier,
839       all_reduce_id);
840 }
841 
CreateAllToAll(const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups)842 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
843     const Shape& shape, absl::Span<HloInstruction* const> operands,
844     const std::vector<ReplicaGroup>& replica_groups) {
845   return absl::make_unique<HloAllToAllInstruction>(shape, operands,
846                                                    replica_groups);
847 }
848 
849 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermute(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64,int64>> & source_target_pairs)850 HloInstruction::CreateCollectivePermute(
851     const Shape& shape, HloInstruction* operand,
852     const std::vector<std::pair<int64, int64>>& source_target_pairs) {
853   return absl::make_unique<HloCollectivePermuteInstruction>(
854       shape, operand, source_target_pairs);
855 }
856 
CreateReplicaId()857 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId() {
858   return absl::WrapUnique(
859       new HloInstruction(HloOpcode::kReplicaId, ShapeUtil::MakeShape(U32, {})));
860 }
861 
CreateInfeed(const Shape & infeed_shape,HloInstruction * token_operand,const string & config)862 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
863     const Shape& infeed_shape, HloInstruction* token_operand,
864     const string& config) {
865   return absl::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
866                                                  config);
867 }
868 
CreateOutfeed(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)869 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
870     const Shape& outfeed_shape, HloInstruction* operand,
871     HloInstruction* token_operand, absl::string_view outfeed_config) {
872   return absl::make_unique<HloOutfeedInstruction>(
873       outfeed_shape, operand, token_operand, outfeed_config);
874 }
875 
CreateSend(HloInstruction * operand,HloInstruction * token,int64 channel_id,bool is_host_transfer)876 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
877     HloInstruction* operand, HloInstruction* token, int64 channel_id,
878     bool is_host_transfer) {
879   return absl::make_unique<HloSendInstruction>(operand, token, channel_id,
880                                                is_host_transfer);
881 }
882 
CreateSendDone(HloInstruction * operand,bool is_host_transfer)883 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
884     HloInstruction* operand, bool is_host_transfer) {
885   auto send_operand = DynCast<HloSendInstruction>(operand);
886   CHECK(send_operand != nullptr)
887       << "SendDone must take the context operand from Send";
888   return absl::make_unique<HloSendDoneInstruction>(send_operand,
889                                                    is_host_transfer);
890 }
891 
CreateRecv(const Shape & shape,HloInstruction * token,int64 channel_id,bool is_host_transfer)892 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
893     const Shape& shape, HloInstruction* token, int64 channel_id,
894     bool is_host_transfer) {
895   return absl::make_unique<HloRecvInstruction>(shape, token, channel_id,
896                                                is_host_transfer);
897 }
898 
CreateRecvDone(HloInstruction * operand,bool is_host_transfer)899 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
900     HloInstruction* operand, bool is_host_transfer) {
901   auto recv_operand = DynCast<HloRecvInstruction>(operand);
902   CHECK(recv_operand != nullptr)
903       << "RecvDone must take the context operand from Recv";
904   return absl::make_unique<HloRecvDoneInstruction>(recv_operand,
905                                                    is_host_transfer);
906 }
907 
CreateReverse(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)908 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
909     const Shape& shape, HloInstruction* operand,
910     absl::Span<const int64> dimensions) {
911   return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
912 }
913 
CreateAfterAll(absl::Span<HloInstruction * const> operands)914 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
915     absl::Span<HloInstruction* const> operands) {
916   CHECK(!operands.empty());
917   auto instruction = absl::WrapUnique(
918       new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
919   for (auto operand : operands) {
920     instruction->AppendOperand(operand);
921   }
922   return instruction;
923 }
924 
CreateToken()925 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
926   return absl::WrapUnique(
927       new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
928 }
929 
930 /* static */ std::unique_ptr<HloInstruction>
CreateAddDependency(HloInstruction * data_operand,HloInstruction * token_operand)931 HloInstruction::CreateAddDependency(HloInstruction* data_operand,
932                                     HloInstruction* token_operand) {
933   auto instruction = absl::WrapUnique(
934       new HloInstruction(HloOpcode::kAddDependency, data_operand->shape()));
935   instruction->AppendOperand(data_operand);
936   instruction->AppendOperand(token_operand);
937   return instruction;
938 }
939 
CreateWhile(const Shape & shape,HloComputation * condition,HloComputation * body,HloInstruction * init)940 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
941     const Shape& shape, HloComputation* condition, HloComputation* body,
942     HloInstruction* init) {
943   auto instruction =
944       absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
945   instruction->AppendOperand(init);
946   // Body comes before condition computation in the vector.
947   instruction->called_computations_.push_back(body);
948   instruction->called_computations_.push_back(condition);
949   return instruction;
950 }
951 
CreateConditional(const Shape & shape,HloInstruction * pred,HloInstruction * true_computation_arg,HloComputation * true_computation,HloInstruction * false_computation_arg,HloComputation * false_computation)952 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
953     const Shape& shape, HloInstruction* pred,
954     HloInstruction* true_computation_arg, HloComputation* true_computation,
955     HloInstruction* false_computation_arg, HloComputation* false_computation) {
956   auto instruction =
957       absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
958   instruction->AppendOperand(pred);
959   instruction->AppendOperand(true_computation_arg);
960   instruction->AppendOperand(false_computation_arg);
961   // In called_computations_, the index of true_computation must be 0 and that
962   // of false computation must be 1, as defined by kTrueComputationIndex and
963   // kFalseComputationIndex.
964   instruction->called_computations_.push_back(true_computation);
965   instruction->called_computations_.push_back(false_computation);
966   return instruction;
967 }
968 
CreateConditional(const Shape & shape,HloInstruction * branch_index,absl::Span<HloComputation * const> branch_computations,absl::Span<HloInstruction * const> branch_computation_args)969 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
970     const Shape& shape, HloInstruction* branch_index,
971     absl::Span<HloComputation* const> branch_computations,
972     absl::Span<HloInstruction* const> branch_computation_args) {
973   auto instruction =
974       absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
975   instruction->AppendOperand(branch_index);
976   CHECK_EQ(branch_computations.size(), branch_computation_args.size());
977   for (int i = 0; i < branch_computations.size(); ++i) {
978     instruction->called_computations_.push_back(branch_computations[i]);
979     instruction->AppendOperand(branch_computation_args[i]);
980   }
981   return instruction;
982 }
983 
CreateSlice(const Shape & shape,HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)984 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
985     const Shape& shape, HloInstruction* operand,
986     absl::Span<const int64> start_indices,
987     absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
988   return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
989                                                 limit_indices, strides);
990 }
991 
CreateDynamicSlice(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)992 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
993     const Shape& shape, HloInstruction* operand,
994     absl::Span<HloInstruction* const> start_indices,
995     absl::Span<const int64> slice_sizes) {
996   return absl::make_unique<HloDynamicSliceInstruction>(
997       shape, operand, start_indices, slice_sizes);
998 }
999 
1000 /* static */ std::unique_ptr<HloInstruction>
CreateDynamicUpdateSlice(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)1001 HloInstruction::CreateDynamicUpdateSlice(
1002     const Shape& shape, HloInstruction* operand, HloInstruction* update,
1003     absl::Span<HloInstruction* const> start_indices) {
1004   return absl::make_unique<HloDynamicUpdateSliceInstruction>(
1005       shape, operand, update, start_indices);
1006 }
1007 
CreateConcatenate(const Shape & shape,absl::Span<HloInstruction * const> operands,int64 dimension)1008 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
1009     const Shape& shape, absl::Span<HloInstruction* const> operands,
1010     int64 dimension) {
1011   return absl::make_unique<HloConcatenateInstruction>(shape, operands,
1012                                                       dimension);
1013 }
1014 
CreateConvert(const Shape & shape,HloInstruction * operand)1015 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
1016     const Shape& shape, HloInstruction* operand) {
1017   auto instruction =
1018       absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
1019   instruction->AppendOperand(operand);
1020   return instruction;
1021 }
1022 
1023 /* static */ std::unique_ptr<HloInstruction>
CreateBitcastConvert(const Shape & shape,HloInstruction * operand)1024 HloInstruction::CreateBitcastConvert(const Shape& shape,
1025                                      HloInstruction* operand) {
1026   auto instruction =
1027       absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
1028   instruction->AppendOperand(operand);
1029   return instruction;
1030 }
1031 
CreateReduce(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)1032 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1033     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1034     absl::Span<const int64> dimensions_to_reduce,
1035     HloComputation* reduce_computation) {
1036   auto instruction = absl::WrapUnique(new HloReduceInstruction(
1037       shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
1038   return std::move(instruction);
1039 }
1040 
CreateReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)1041 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1042     const Shape& shape, absl::Span<HloInstruction* const> operands,
1043     absl::Span<HloInstruction* const> init_values,
1044     absl::Span<const int64> dimensions_to_reduce,
1045     HloComputation* reduce_computation) {
1046   std::vector<HloInstruction*> all_args;
1047   all_args.reserve(operands.size() * 2);
1048   all_args.insert(all_args.end(), operands.begin(), operands.end());
1049   all_args.insert(all_args.end(), init_values.begin(), init_values.end());
1050   return absl::make_unique<HloReduceInstruction>(
1051       shape, all_args, dimensions_to_reduce, reduce_computation);
1052 }
1053 
CreateReduceWindow(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)1054 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
1055     const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1056     const Window& window, HloComputation* reduce_computation) {
1057   return absl::make_unique<HloReduceWindowInstruction>(
1058       shape, operand, init_value, window, reduce_computation);
1059 }
1060 
1061 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormTraining(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64 feature_index)1062 HloInstruction::CreateBatchNormTraining(const Shape& shape,
1063                                         HloInstruction* operand,
1064                                         HloInstruction* scale,
1065                                         HloInstruction* offset, float epsilon,
1066                                         int64 feature_index) {
1067   return absl::make_unique<HloBatchNormTrainingInstruction>(
1068       shape, operand, scale, offset, epsilon, feature_index);
1069 }
1070 
1071 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormInference(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64 feature_index)1072 HloInstruction::CreateBatchNormInference(
1073     const Shape& shape, HloInstruction* operand, HloInstruction* scale,
1074     HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
1075     float epsilon, int64 feature_index) {
1076   return absl::make_unique<HloBatchNormInferenceInstruction>(
1077       shape, operand, scale, offset, mean, variance, epsilon, feature_index);
1078 }
1079 
1080 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormGrad(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64 feature_index)1081 HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
1082                                     HloInstruction* scale, HloInstruction* mean,
1083                                     HloInstruction* variance,
1084                                     HloInstruction* grad_output, float epsilon,
1085                                     int64 feature_index) {
1086   return absl::make_unique<HloBatchNormGradInstruction>(
1087       shape, operand, scale, mean, variance, grad_output, epsilon,
1088       feature_index);
1089 }
1090 
1091 /* static */ std::unique_ptr<HloInstruction>
CreateSelectAndScatter(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)1092 HloInstruction::CreateSelectAndScatter(
1093     const Shape& shape, HloInstruction* operand, HloComputation* select,
1094     const Window& window, HloInstruction* source, HloInstruction* init_value,
1095     HloComputation* scatter) {
1096   return absl::make_unique<HloSelectAndScatterInstruction>(
1097       shape, operand, select, window, source, init_value, scatter);
1098 }
1099 
CreateBroadcast(const Shape & shape,HloInstruction * operand,absl::Span<const int64> broadcast_dimensions)1100 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
1101     const Shape& shape, HloInstruction* operand,
1102     absl::Span<const int64> broadcast_dimensions) {
1103   return absl::make_unique<HloBroadcastInstruction>(shape, operand,
1104                                                     broadcast_dimensions);
1105 }
1106 
1107 /* static */ std::unique_ptr<HloInstruction>
CreateGetDimensionSize(const Shape & shape,HloInstruction * operand,int64 dimension)1108 HloInstruction::CreateGetDimensionSize(const Shape& shape,
1109                                        HloInstruction* operand,
1110                                        int64 dimension) {
1111   return absl::make_unique<HloGetDimensionSizeInstruction>(shape, operand,
1112                                                            dimension);
1113 }
1114 
1115 /* static */ std::unique_ptr<HloInstruction>
CreateBroadcastSequence(const Shape & output_shape,HloInstruction * operand,const std::function<HloInstruction * (std::unique_ptr<HloInstruction>)> & adder)1116 HloInstruction::CreateBroadcastSequence(
1117     const Shape& output_shape, HloInstruction* operand,
1118     const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
1119         adder) {
1120   CHECK(ShapeUtil::IsScalar(operand->shape()) ||
1121         operand->shape().rank() == output_shape.rank());
1122   Shape broadcast_shape = ShapeUtil::ChangeElementType(
1123       output_shape, operand->shape().element_type());
1124   // Do explicit broadcast for scalar.
1125   if (ShapeUtil::IsScalar(operand->shape())) {
1126     auto broadcast =
1127         HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
1128     broadcast->set_metadata(operand->metadata());
1129     if (operand->has_sharding()) {
1130       broadcast->set_sharding(operand->sharding());
1131     }
1132     return broadcast;
1133   }
1134   // Do explicit broadcast for degenerate broadcast.
1135   std::vector<int64> broadcast_dimensions;
1136   std::vector<int64> reshaped_dimensions;
1137   for (int i = 0; i < operand->shape().rank(); i++) {
1138     if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
1139       broadcast_dimensions.push_back(i);
1140       reshaped_dimensions.push_back(operand->shape().dimensions(i));
1141     } else {
1142       CHECK_EQ(operand->shape().dimensions(i), 1)
1143           << "An explicit broadcast sequence requires the broadcasted "
1144              "dimensions to be trivial; operand: "
1145           << operand->ToString() << "; output_shape: " << output_shape;
1146     }
1147   }
1148   // Eliminate the size one dimensions.
1149   HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
1150       ShapeUtil::MakeShape(operand->shape().element_type(),
1151                            reshaped_dimensions),
1152       operand));
1153   reshaped_operand->set_metadata(operand->metadata());
1154   if (operand->has_sharding()) {
1155     reshaped_operand->set_sharding(operand->sharding());
1156   }
1157   // Broadcast 'reshape' up to the larger size.
1158   auto broadcast = HloInstruction::CreateBroadcast(
1159       broadcast_shape, reshaped_operand, broadcast_dimensions);
1160   broadcast->set_metadata(operand->metadata());
1161   if (operand->has_sharding()) {
1162     broadcast->set_sharding(operand->sharding());
1163   }
1164   return broadcast;
1165 }
1166 
CreatePad(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)1167 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
1168     const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
1169     const PaddingConfig& padding_config) {
1170   return absl::make_unique<HloPadInstruction>(shape, operand, padding_value,
1171                                               padding_config);
1172 }
1173 
CreateReshape(const Shape & shape,HloInstruction * operand)1174 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
1175     const Shape& shape, HloInstruction* operand) {
1176   CHECK_EQ(ShapeUtil::ElementsIn(shape),
1177            ShapeUtil::ElementsIn(operand->shape()))
1178       << "shape: " << ShapeUtil::HumanString(shape)
1179       << " operand: " << ShapeUtil::HumanString(operand->shape());
1180   auto instruction =
1181       absl::WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
1182   instruction->AppendOperand(operand);
1183   return instruction;
1184 }
1185 
CreateTranspose(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)1186 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
1187     const Shape& shape, HloInstruction* operand,
1188     absl::Span<const int64> dimensions) {
1189   return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
1190 }
1191 
CreateSort(const Shape & shape,int64 dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)1192 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
1193     const Shape& shape, int64 dimension,
1194     absl::Span<HloInstruction* const> operands, HloComputation* compare,
1195     bool is_stable) {
1196   return absl::make_unique<HloSortInstruction>(shape, dimension, operands,
1197                                                compare, is_stable);
1198 }
1199 
CreateFusion(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1200 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1201     const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
1202   return absl::make_unique<HloFusionInstruction>(shape, fusion_kind,
1203                                                  fused_root);
1204 }
1205 
CreateFusion(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation)1206 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1207     const Shape& shape, FusionKind fusion_kind,
1208     absl::Span<HloInstruction* const> operands,
1209     HloComputation* fusion_computation) {
1210   return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
1211                                                  fusion_computation);
1212 }
1213 
set_single_sharding(const HloSharding & sharding)1214 void HloInstruction::set_single_sharding(const HloSharding& sharding) {
1215   CHECK(!sharding.IsTuple()) << sharding;
1216   if (shape().IsTuple()) {
1217     set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
1218   } else {
1219     set_sharding(sharding);
1220   }
1221 }
1222 
SetupDerivedInstruction(HloInstruction * derived_instruction) const1223 void HloInstruction::SetupDerivedInstruction(
1224     HloInstruction* derived_instruction) const {
1225   if (sharding_ != nullptr && ShapeUtil::CompatibleIgnoringElementType(
1226                                   shape_, derived_instruction->shape())) {
1227     // Only copy sharding if the shape of the two instruction is compatible
1228     // because copying it between differently shaped instructions can produce
1229     // invalid shardings.
1230     derived_instruction->set_sharding(*sharding_);
1231   } else {
1232     derived_instruction->clear_sharding();
1233   }
1234   derived_instruction->set_metadata(metadata_);
1235 }
1236 
HasSideEffectNoRecurse() const1237 bool HloInstruction::HasSideEffectNoRecurse() const {
1238   switch (opcode_) {
1239     case HloOpcode::kSend:
1240     case HloOpcode::kSendDone:
1241     case HloOpcode::kRecv:
1242     case HloOpcode::kRecvDone:
1243     case HloOpcode::kRng:
1244     case HloOpcode::kInfeed:
1245     case HloOpcode::kOutfeed:
1246     case HloOpcode::kTrace:
1247       return true;
1248     case HloOpcode::kAllReduce:
1249       return all_reduce_id().has_value();
1250     default:
1251       return false;
1252   }
1253 }
1254 
HasSideEffect() const1255 bool HloInstruction::HasSideEffect() const {
1256   if (HasSideEffectNoRecurse()) {
1257     return true;
1258   }
1259   // Check if any of the called computations has a side effect.
1260   for (const auto& computation : called_computations()) {
1261     if (computation->HasSideEffect()) {
1262       return true;
1263     }
1264   }
1265   return false;
1266 }
1267 
CreateCall(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * computation)1268 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
1269     const Shape& shape, absl::Span<HloInstruction* const> operands,
1270     HloComputation* computation) {
1271   std::unique_ptr<HloInstruction> instruction =
1272       absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
1273   for (auto operand : operands) {
1274     instruction->AppendOperand(operand);
1275   }
1276   instruction->called_computations_.push_back(computation);
1277   return instruction;
1278 }
1279 
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,absl::string_view opaque)1280 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1281     const Shape& shape, absl::Span<HloInstruction* const> operands,
1282     absl::string_view custom_call_target, absl::string_view opaque) {
1283   return absl::make_unique<HloCustomCallInstruction>(
1284       shape, operands, custom_call_target, opaque);
1285 }
1286 
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,absl::Span<const Shape> operand_shapes_with_layout,absl::string_view opaque)1287 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1288     const Shape& shape, absl::Span<HloInstruction* const> operands,
1289     absl::string_view custom_call_target,
1290     absl::Span<const Shape> operand_shapes_with_layout,
1291     absl::string_view opaque) {
1292   return absl::make_unique<HloCustomCallInstruction>(
1293       shape, operands, custom_call_target, opaque, operand_shapes_with_layout);
1294 }
1295 
CreateTuple(absl::Span<HloInstruction * const> elements)1296 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
1297     absl::Span<HloInstruction* const> elements) {
1298   std::vector<Shape> element_shapes;
1299   for (auto element : elements) {
1300     element_shapes.push_back(element->shape());
1301   }
1302   Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
1303   return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
1304 }
1305 
CreateGather(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes)1306 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
1307     const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
1308     const GatherDimensionNumbers& gather_dim_numbers,
1309     absl::Span<const int64> slice_sizes) {
1310   return absl::make_unique<HloGatherInstruction>(
1311       shape, operand, start_indices, gather_dim_numbers, slice_sizes);
1312 }
1313 
CreateScatter(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers)1314 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
1315     const Shape& shape, HloInstruction* operand,
1316     HloInstruction* scatter_indices, HloInstruction* updates,
1317     HloComputation* update_computation,
1318     const ScatterDimensionNumbers& scatter_dim_numbers) {
1319   return absl::make_unique<HloScatterInstruction>(
1320       shape, operand, scatter_indices, updates, update_computation,
1321       scatter_dim_numbers);
1322 }
1323 
CreateDomain(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)1324 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
1325     const Shape& shape, HloInstruction* operand,
1326     std::unique_ptr<DomainMetadata> operand_side_metadata,
1327     std::unique_ptr<DomainMetadata> user_side_metadata) {
1328   return absl::make_unique<HloDomainInstruction>(
1329       shape, operand, std::move(operand_side_metadata),
1330       std::move(user_side_metadata));
1331 }
1332 
CloneWithNewOperands(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context) const1333 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
1334     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1335     HloCloneContext* context) const {
1336   VLOG(3) << "CloneWithNewOperands:\n  " << ToString();
1337   VLOG(3) << "  new operands:";
1338   for (const HloInstruction* new_operand : new_operands) {
1339     VLOG(3) << "    %" << new_operand->name();
1340   }
1341 
1342   std::unique_ptr<HloInstruction> clone;
1343   // Explicitly call the factory for the instruction type. This is more robust
1344   // in the face of code changes than copying fields explicitly. This also
1345   // properly sets the user fields of the operands.
1346   switch (opcode_) {
1347     // Ops migrated to subclasses.
1348     // TODO(b/80131774): Remove this switch when migration is complete.
1349     case HloOpcode::kBatchNormTraining:
1350     case HloOpcode::kBatchNormInference:
1351     case HloOpcode::kBatchNormGrad:
1352     case HloOpcode::kFft:
1353     case HloOpcode::kCompare:
1354     case HloOpcode::kSend:
1355     case HloOpcode::kSendDone:
1356     case HloOpcode::kRecv:
1357     case HloOpcode::kRecvDone:
1358     case HloOpcode::kReverse:
1359     case HloOpcode::kConcatenate:
1360     case HloOpcode::kReduce:
1361     case HloOpcode::kTranspose:
1362     case HloOpcode::kBroadcast:
1363     case HloOpcode::kMap:
1364     case HloOpcode::kSlice:
1365     case HloOpcode::kConstant:
1366     case HloOpcode::kTrace:
1367     case HloOpcode::kFusion:
1368     case HloOpcode::kRng:
1369     case HloOpcode::kParameter:
1370     case HloOpcode::kGetTupleElement:
1371     case HloOpcode::kReducePrecision:
1372     case HloOpcode::kAllReduce:
1373     case HloOpcode::kAllToAll:
1374     case HloOpcode::kCollectivePermute:
1375     case HloOpcode::kInfeed:
1376     case HloOpcode::kOutfeed:
1377     case HloOpcode::kConvolution:
1378     case HloOpcode::kCustomCall:
1379     case HloOpcode::kReduceWindow:
1380     case HloOpcode::kSelectAndScatter:
1381     case HloOpcode::kPad:
1382     case HloOpcode::kDynamicSlice:
1383     case HloOpcode::kSort:
1384     case HloOpcode::kGather:
1385     case HloOpcode::kScatter:
1386     case HloOpcode::kIota:
1387     case HloOpcode::kDot:
1388     case HloOpcode::kDomain:
1389     case HloOpcode::kGetDimensionSize:
1390     case HloOpcode::kTriangularSolve:
1391     case HloOpcode::kCholesky:
1392       clone = CloneWithNewOperandsImpl(shape, new_operands, context);
1393       break;
1394     // Unary ops.
1395     case HloOpcode::kAbs:
1396     case HloOpcode::kRoundNearestAfz:
1397     case HloOpcode::kBitcast:
1398     case HloOpcode::kCeil:
1399     case HloOpcode::kClz:
1400     case HloOpcode::kCopy:
1401     case HloOpcode::kCos:
1402     case HloOpcode::kExp:
1403     case HloOpcode::kExpm1:
1404     case HloOpcode::kImag:
1405     case HloOpcode::kIsFinite:
1406     case HloOpcode::kFloor:
1407     case HloOpcode::kLog:
1408     case HloOpcode::kLog1p:
1409     case HloOpcode::kNot:
1410     case HloOpcode::kNegate:
1411     case HloOpcode::kReal:
1412     case HloOpcode::kRsqrt:
1413     case HloOpcode::kSign:
1414     case HloOpcode::kSin:
1415     case HloOpcode::kSqrt:
1416     case HloOpcode::kTanh:
1417       CHECK_EQ(new_operands.size(), 1);
1418       clone = CreateUnary(shape, opcode_, new_operands[0]);
1419       break;
1420     // Binary ops.
1421     case HloOpcode::kAdd:
1422     case HloOpcode::kAtan2:
1423     case HloOpcode::kComplex:
1424     case HloOpcode::kDivide:
1425     case HloOpcode::kMultiply:
1426     case HloOpcode::kSubtract:
1427     case HloOpcode::kMaximum:
1428     case HloOpcode::kMinimum:
1429     case HloOpcode::kPower:
1430     case HloOpcode::kRemainder:
1431     case HloOpcode::kAnd:
1432     case HloOpcode::kOr:
1433     case HloOpcode::kXor:
1434     case HloOpcode::kShiftLeft:
1435     case HloOpcode::kShiftRightArithmetic:
1436     case HloOpcode::kShiftRightLogical:
1437       CHECK_EQ(new_operands.size(), 2);
1438       clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
1439       break;
1440     // Ternary ops.
1441     case HloOpcode::kClamp:
1442     case HloOpcode::kSelect:
1443     case HloOpcode::kTupleSelect:
1444       CHECK_EQ(new_operands.size(), 3);
1445       clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
1446                             new_operands[2]);
1447       break;
1448     // Other supported ops.
1449     case HloOpcode::kCall:
1450       clone = CreateCall(shape, new_operands, to_apply());
1451       break;
1452     case HloOpcode::kConvert:
1453       CHECK_EQ(new_operands.size(), 1);
1454       clone = CreateConvert(shape, new_operands[0]);
1455       break;
1456     case HloOpcode::kBitcastConvert:
1457       CHECK_EQ(new_operands.size(), 1);
1458       clone = CreateBitcastConvert(shape, new_operands[0]);
1459       break;
1460     case HloOpcode::kReshape:
1461       CHECK_EQ(new_operands.size(), 1);
1462       clone = CreateReshape(shape, new_operands[0]);
1463       break;
1464     case HloOpcode::kDynamicUpdateSlice:
1465       clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
1466                                        new_operands.subspan(2));
1467       break;
1468     case HloOpcode::kTuple:
1469       clone = CreateTuple(new_operands);
1470       *clone->mutable_shape() = shape;
1471       break;
1472     case HloOpcode::kWhile:
1473       CHECK_EQ(new_operands.size(), 1);
1474       clone =
1475           CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
1476       break;
1477     case HloOpcode::kConditional:
1478       CHECK_EQ(new_operands.size(), branch_count() + 1);
1479       clone = CreateConditional(shape, new_operands[0],
1480                                 absl::MakeSpan(branch_computations()),
1481                                 new_operands.subspan(1));
1482       break;
1483     case HloOpcode::kAfterAll:
1484       if (new_operands.empty()) {
1485         clone = CreateToken();
1486       } else {
1487         clone = CreateAfterAll(new_operands);
1488       }
1489       break;
1490     case HloOpcode::kAddDependency:
1491       CHECK_EQ(new_operands.size(), 2);
1492       clone = CreateAddDependency(new_operands[0], new_operands[1]);
1493       break;
1494     case HloOpcode::kReplicaId:
1495       CHECK_EQ(new_operands.size(), 0);
1496       clone = CreateReplicaId();
1497       break;
1498   }
1499   // SetupDerivedInstruction will setup the precision_config_ field.
1500   SetupDerivedInstruction(clone.get());
1501   clone->set_parent(parent_);
1502   clone->set_raw_backend_config_string(backend_config_);
1503   if (context != nullptr) {
1504     context->MapInstruction(this, clone.get());
1505     clone->ReplaceCalledComputations([&](HloComputation* callee) {
1506       return callee->parent() != context->module()
1507                  ? context->module()->DeepCloneComputation(callee, context)
1508                  : callee;
1509     });
1510   }
1511   return clone;
1512 }
1513 
~HloInstruction()1514 HloInstruction::~HloInstruction() {
1515   // Detach from operands. An instruction may be repeated as an operand. To
1516   // avoid calling RemoveUser twice on the same operand, check before remove.
1517   for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
1518     HloInstruction* operand = operands_[operand_num];
1519     if (operand == nullptr) {
1520       continue;
1521     }
1522     if (operand->user_set_.find(this) != operand->user_set_.end()) {
1523       operand->RemoveUser(this);
1524     }
1525     operands_[operand_num] = nullptr;
1526   }
1527 
1528   // Update users. Set `nullptr` to the correpsonding operand slot for users.
1529   for (auto& user : this->users()) {
1530     for (int i = 0; i < user->operand_count(); ++i) {
1531       if (user->operands_[i] == this) {
1532         user->operands_[i] = nullptr;
1533       }
1534     }
1535   }
1536 }
1537 
Clone(const string & suffix,HloCloneContext * context) const1538 std::unique_ptr<HloInstruction> HloInstruction::Clone(
1539     const string& suffix, HloCloneContext* context) const {
1540   std::unique_ptr<HloInstruction> clone =
1541       CloneWithNewOperands(shape_, operands_, context);
1542   if (suffix.empty()) {
1543     clone->name_ = name();
1544   } else {
1545     // If an instruction is cloned multiple times avoid names like
1546     // foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric
1547     // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the
1548     // clone of foo.suffix2 is named foo.suffix3 and so on.
1549     const string dot_suffix = "." + suffix;
1550     size_t index = name().rfind(dot_suffix);
1551     if (index == string::npos) {
1552       // Existing name does not include ".suffix".
1553       clone->name_ = name() + dot_suffix;
1554     } else {
1555       // Existing name includes ".suffix". Determine if substring after
1556       // ".suffix" is numeric and should be replaced with an incremented number.
1557       string after_suffix = name().substr(index + dot_suffix.size());
1558       if (after_suffix.empty()) {
1559         // Existing name ends in ".suffix". New name should end in ".suffix2".
1560         clone->name_ = name() + "2";
1561       } else {
1562         // If names ends with .suffix[0-9]+ then replace with a suffix with the
1563         // numeric value incremented.
1564         int64 numeric_suffix;
1565         if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
1566           clone->name_ =
1567               StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
1568         } else {
1569           // Substring after ".suffix" is non-numeric.
1570           clone->name_ = name() + dot_suffix;
1571         }
1572       }
1573     }
1574   }
1575   return clone;
1576 }
1577 
1578 std::pair<const HloInstruction*, ShapeIndex>
LatestNonGteAncestorAndIndex() const1579 HloInstruction::LatestNonGteAncestorAndIndex() const {
1580   const HloInstruction* hlo = this;
1581   ShapeIndex index;
1582   while (hlo->opcode() == HloOpcode::kGetTupleElement) {
1583     index.push_back(hlo->tuple_index());
1584     hlo = hlo->operand(0);
1585   }
1586 
1587   // We built up index in the reverse order from what we want.
1588   std::reverse(index.begin(), index.end());
1589 
1590   return {hlo, index};
1591 }
1592 
LatestNonGteAncestor() const1593 const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
1594   const HloInstruction* hlo = this;
1595   while (hlo->opcode() == HloOpcode::kGetTupleElement) {
1596     hlo = hlo->operand(0);
1597   }
1598   return hlo;
1599 }
1600 
operand(int64 i) const1601 const HloInstruction* HloInstruction::operand(int64 i) const {
1602   return operands_[i];
1603 }
1604 
mutable_operand(int64 i)1605 HloInstruction* HloInstruction::mutable_operand(int64 i) {
1606   CHECK(operands_[i] != nullptr);
1607   return operands_[i];
1608 }
1609 
operand_index(const HloInstruction * target) const1610 int64 HloInstruction::operand_index(const HloInstruction* target) const {
1611   for (int64 i = 0; i < operand_count(); ++i) {
1612     if (target == operand(i)) {
1613       return i;
1614     }
1615   }
1616   LOG(FATAL) << "target was not an operand: " << target->ToString();
1617 }
1618 
unique_operands() const1619 HloInstruction::InstructionVector HloInstruction::unique_operands() const {
1620   InstructionVector unique;
1621   absl::flat_hash_set<const HloInstruction*> seen;
1622   for (HloInstruction* operand : operands()) {
1623     if (seen.insert(operand).second) {
1624       unique.push_back(operand);
1625     }
1626   }
1627   return unique;
1628 }
1629 
AddControlDependencyTo(HloInstruction * instruction)1630 Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
1631   TF_RET_CHECK(instruction->parent() == parent());
1632   if (!absl::c_linear_search(control_successors_, instruction)) {
1633     control_successors_.push_back(instruction);
1634     TF_RET_CHECK(
1635         !absl::c_linear_search(instruction->control_predecessors_, this));
1636     instruction->control_predecessors_.push_back(this);
1637   }
1638   return Status::OK();
1639 }
1640 
RemoveControlDependencyTo(HloInstruction * instruction)1641 Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
1642   TF_RET_CHECK(instruction->parent() == parent());
1643   TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction));
1644   TF_RETURN_IF_ERROR(
1645       EraseElementFromVector(&instruction->control_predecessors_, this));
1646   return Status::OK();
1647 }
1648 
DropAllControlDeps()1649 Status HloInstruction::DropAllControlDeps() {
1650   for (auto* ctrl_succ : control_successors_) {
1651     TF_RETURN_IF_ERROR(
1652         EraseElementFromVector(&ctrl_succ->control_predecessors_, this));
1653   }
1654   for (auto* ctrl_pred : control_predecessors_) {
1655     TF_RETURN_IF_ERROR(
1656         EraseElementFromVector(&ctrl_pred->control_successors_, this));
1657   }
1658   control_successors_.clear();
1659   control_predecessors_.clear();
1660   return Status::OK();
1661 }
1662 
CopyAllControlDepsFrom(const HloInstruction * inst)1663 Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) {
1664   for (auto* ctrl_pred : inst->control_predecessors()) {
1665     TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(this));
1666   }
1667 
1668   for (auto* ctrl_succ : inst->control_successors()) {
1669     TF_RETURN_IF_ERROR(this->AddControlDependencyTo(ctrl_succ));
1670   }
1671 
1672   return Status::OK();
1673 }
1674 
// Appends `operand` as the last operand of this instruction and registers
// this instruction as a user of `operand`. `operand` is dereferenced, so it
// must not be null.
void HloInstruction::AppendOperand(HloInstruction* operand) {
  operands_.push_back(operand);
  operand->AddUser(this);
}
1679 
RemoveOperandsAtAscendingIndices(absl::Span<const int> ascending_indices)1680 void HloInstruction::RemoveOperandsAtAscendingIndices(
1681     absl::Span<const int> ascending_indices) {
1682   if (ascending_indices.empty()) {
1683     return;
1684   }
1685   int next_index = 0;
1686   int removed_count = 0;
1687   for (int to_remove : ascending_indices) {
1688     while (next_index < to_remove) {
1689       operands_[next_index - removed_count] = operands_[next_index];
1690       ++next_index;
1691     }
1692     CHECK_LT(to_remove, operands_.size());
1693     ++removed_count;
1694     ++next_index;
1695   }
1696   while (next_index < operands_.size()) {
1697     operands_[next_index - removed_count] = operands_[next_index];
1698     ++next_index;
1699   }
1700   CHECK_EQ(removed_count, ascending_indices.size());
1701   operands_.resize(operands_.size() - removed_count);
1702 }
1703 
AddUser(HloInstruction * user)1704 void HloInstruction::AddUser(HloInstruction* user) {
1705   if (!ContainsKey(user_set_, user)) {
1706     user_set_.insert(user);
1707     users_.push_back(user);
1708   }
1709 }
1710 
HasConstantOperand() const1711 bool HloInstruction::HasConstantOperand() const {
1712   for (const HloInstruction* operand : operands_) {
1713     if (operand->IsConstant()) {
1714       return true;
1715     }
1716   }
1717   return false;
1718 }
1719 
// Opcode-specific part of comparing this instruction with `other`.
//
// The decision is driven by this instruction's opcode only:
//   * opcodes in the first group return true unconditionally;
//   * kAfterAll and kAddDependency always return false;
//   * kCall, kConditional and kWhile compare their called computations with
//     `eq_computations`;
//   * opcodes that have a dedicated subclass abort with LOG(FATAL), because
//     the base-class implementation must not be reached for them.
//
// NOTE(review): nothing here compares opcodes, shapes or operands of `other`;
// presumably the caller has already done so — confirm against the caller.
bool HloInstruction::IdenticalSlowPath(
    const HloInstruction& other,
    const std::function<bool(const HloComputation*, const HloComputation*)>&
        eq_computations) const {
  // Perform opcode specific checks.
  switch (opcode()) {
    // The result of these instructions only depend upon their opcode and
    // operands.
    case HloOpcode::kAbs:
    case HloOpcode::kAtan2:
    case HloOpcode::kAdd:
    case HloOpcode::kBitcast:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kComplex:
    case HloOpcode::kConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kAnd:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kRemainder:
    case HloOpcode::kReshape:
    case HloOpcode::kReplicaId:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kTuple:
    case HloOpcode::kTupleSelect:
      return true;

    // This opcode has complex or special behavior so just return false.
    case HloOpcode::kAfterAll:
    case HloOpcode::kAddDependency:
      return false;

    // Remaining instructions with special values.
    case HloOpcode::kCall:
      return eq_computations(to_apply(), other.to_apply());
    case HloOpcode::kConditional:
      // NOTE(review): iterates over this instruction's branch_count() and
      // indexes `other` with the same j; presumably both instructions are
      // known to have the same number of branches here — confirm.
      for (int j = 0; j < branch_count(); ++j) {
        if (!eq_computations(branch_computation(j),
                             other.branch_computation(j))) {
          return false;
        }
      }
      return true;
    case HloOpcode::kWhile:
      return (eq_computations(while_body(), other.while_body()) &&
              eq_computations(while_condition(), other.while_condition()));

    // Ops migrated to subclasses should never come to this line.
    // TODO(b/80131774): Remove this switch when migration is complete.
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kFft:
    case HloOpcode::kCompare:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReverse:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReduce:
    case HloOpcode::kSort:
    case HloOpcode::kTranspose:
    case HloOpcode::kBroadcast:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
    case HloOpcode::kConstant:
    case HloOpcode::kIota:
    case HloOpcode::kTrace:
    case HloOpcode::kFusion:
    case HloOpcode::kRng:
    case HloOpcode::kParameter:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kConvolution:
    case HloOpcode::kCustomCall:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kPad:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kGather:
    case HloOpcode::kScatter:
    case HloOpcode::kDot:
    case HloOpcode::kDomain:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      LOG(FATAL) << "Base class impl called for opcode with subclass: "
                 << opcode();
  }
  // Reached only if the opcode matched none of the cases above.
  return false;
}
1847 
HashOperand(const HloInstruction * hlo)1848 static uint64 HashOperand(const HloInstruction* hlo) {
1849   return ShapeUtil::Hash(hlo->shape());
1850 }
1851 
Hash(const std::function<uint64 (const HloInstruction *)> & hash_operand) const1852 uint64 HloInstruction::Hash(
1853     const std::function<uint64(const HloInstruction*)>& hash_operand) const {
1854   using tensorflow::Hash64Combine;
1855 
1856   uint64 hash_value = Hash64Combine(0, static_cast<uint64>(opcode()));
1857   hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(shape()));
1858 
1859   if (!IsCrossModuleAllReduce()) {
1860     if (!operands().empty()) {
1861       for (size_t i = 0; i < operands().size(); ++i) {
1862         hash_value = Hash64Combine(hash_value, hash_operand(operand(i)));
1863       }
1864     }
1865   }
1866 
1867   hash_value = Hash64Combine(hash_value, InnerHash());
1868   return hash_value;
1869 }
1870 
Hash() const1871 uint64 HloInstruction::Hash() const {
1872   // Use HashOperand as an argument to prevent non-termination.
1873   return Hash(HashOperand);
1874 }
1875 
InnerHash() const1876 uint64 HloInstruction::InnerHash() const { return 13; }
1877 
RemoveUser(HloInstruction * user)1878 void HloInstruction::RemoveUser(HloInstruction* user) {
1879   auto set_it = user_set_.find(user);
1880   CHECK(set_it != user_set_.end());
1881   user_set_.erase(set_it);
1882   // This is linear in the number of the users, but a vector provides a stable
1883   // iteration order and much faster traversal.
1884   auto vec_it = absl::c_find(users_, user);
1885   CHECK(vec_it != users_.end());
1886   users_.erase(vec_it);
1887 }
1888 
// Replaces the uses of this instruction inside `user` with `new_producer`.
// Returns an error if the shape of `new_producer` is not compatible with this
// instruction's shape (floating-point precision differences are ignored);
// otherwise defers to ReplaceUseWithDifferentShape and returns its status.
Status HloInstruction::ReplaceUseWith(HloInstruction* user,
                                      HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << "this shape: " << ShapeUtil::HumanString(shape())
      << ", replacement shape: "
      << ShapeUtil::HumanString(new_producer->shape());
  return ReplaceUseWithDifferentShape(user, new_producer);
}
1898 
// Replaces every operand slot of `user` that refers to this instruction with
// `new_producer`, without checking shape compatibility.
//
// `user` is removed from this instruction's users (RemoveUser CHECK-fails if
// it is not currently registered) and added to the users of `new_producer`.
// If `user` is a fusion instruction, its operands are deduplicated
// afterwards and any error from that step is returned.
Status HloInstruction::ReplaceUseWithDifferentShape(
    HloInstruction* user, HloInstruction* new_producer) {
  VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
          << " with " << new_producer->name();

  RemoveUser(user);

  // NOTE(review): a count is never negative, so this check cannot fail as
  // written; presumably `> 0` was intended. Also note that RemoveUser has
  // already run by the time it is evaluated — confirm before tightening.
  TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0);
  std::replace(user->operands_.begin(), user->operands_.end(), this,
               new_producer);
  new_producer->AddUser(user);
  // The replacement may have made two operands of the fusion identical.
  if (user->opcode() == HloOpcode::kFusion) {
    TF_RETURN_IF_ERROR(
        Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
  }
  return Status::OK();
}
1916 
ReplaceOperandWith(int64 operand_num,HloInstruction * new_operand)1917 Status HloInstruction::ReplaceOperandWith(int64 operand_num,
1918                                           HloInstruction* new_operand) {
1919   auto old_operand = operand(operand_num);
1920   TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
1921                                                         new_operand->shape()))
1922       << old_operand->shape() << " is not compatible with "
1923       << new_operand->shape();
1924   return ReplaceOperandWithDifferentShape(operand_num, new_operand);
1925 }
1926 
ReplaceOperandWithDifferentShape(int64 operand_num,HloInstruction * new_operand)1927 Status HloInstruction::ReplaceOperandWithDifferentShape(
1928     int64 operand_num, HloInstruction* new_operand) {
1929   TF_RET_CHECK(operand_num >= 0);
1930   TF_RET_CHECK(operand_num < operand_count());
1931   HloInstruction* old_operand = mutable_operand(operand_num);
1932   if (old_operand == new_operand) {
1933     return Status::OK();
1934   }
1935 
1936   operands_[operand_num] = new_operand;
1937 
1938   VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
1939           << new_operand->name() << ", was " << old_operand->name();
1940 
1941   if (!absl::c_linear_search(operands_, old_operand)) {
1942     old_operand->RemoveUser(this);
1943   }
1944   new_operand->AddUser(this);
1945   return Status::OK();
1946 }
1947 
// Redirects all uses of this instruction to `new_producer`. Returns an error
// if the shape of `new_producer` is not compatible with this instruction's
// shape (floating-point precision differences are ignored); otherwise defers
// to ReplaceAllUsesWithDifferentShape and returns its status.
Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << shape() << " is not compatible with " << new_producer->shape();
  return ReplaceAllUsesWithDifferentShape(new_producer);
}
1954 
ReplaceAllUsesWithDifferentShape(HloInstruction * new_producer)1955 Status HloInstruction::ReplaceAllUsesWithDifferentShape(
1956     HloInstruction* new_producer) {
1957   bool new_producer_is_user = false;
1958   for (HloInstruction* user : users()) {
1959     if (user == new_producer) {
1960       // It's possible that new_producer is a user of this instruction as might
1961       // be the case when replacing an instruction with a kCopy of itself. In
1962       // this case, don't do the replacement to avoid creating a cycle in the
1963       // graph. new_producer remains the only user of this instruction.
1964       new_producer_is_user = true;
1965     } else {
1966       std::replace(user->operands_.begin(), user->operands_.end(), this,
1967                    new_producer);
1968       new_producer->AddUser(user);
1969       if (user->opcode() == HloOpcode::kFusion) {
1970         TF_RETURN_IF_ERROR(
1971             Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
1972       }
1973     }
1974   }
1975   users_.clear();
1976   user_set_.clear();
1977   if (new_producer_is_user) {
1978     AddUser(new_producer);
1979   }
1980   if (parent_ && parent_->root_instruction() == this) {
1981     parent_->set_root_instruction(new_producer,
1982                                   /*accept_different_shape=*/true);
1983   }
1984 
1985   return Status::OK();
1986 }
1987 
to_apply() const1988 HloComputation* HloInstruction::to_apply() const {
1989   switch (opcode_) {
1990     case HloOpcode::kCall:
1991     case HloOpcode::kMap:
1992     case HloOpcode::kReduceWindow:
1993     case HloOpcode::kReduce:
1994     case HloOpcode::kAllReduce:
1995     case HloOpcode::kScatter:
1996     case HloOpcode::kSort:
1997       CHECK_EQ(called_computations_.size(), 1);
1998       return called_computations_[0];
1999     default:
2000       LOG(FATAL) << "Invalid opcode for to_apply(): "
2001                  << HloOpcodeString(opcode());
2002   }
2003 }
2004 
set_to_apply(HloComputation * computation)2005 void HloInstruction::set_to_apply(HloComputation* computation) {
2006   // Don't allow changing the computation for fused instructions so we don't
2007   // have to recompute called_instructions for the entire fusion instruction.
2008   CHECK(!IsFused());
2009   switch (opcode_) {
2010     case HloOpcode::kCall:
2011     case HloOpcode::kMap:
2012     case HloOpcode::kReduceWindow:
2013     case HloOpcode::kReduce:
2014     case HloOpcode::kAllReduce:
2015     case HloOpcode::kScatter:
2016     case HloOpcode::kSort:
2017       CHECK_EQ(called_computations_.size(), 1);
2018       called_computations_[0] = computation;
2019       break;
2020     default:
2021       LOG(FATAL) << "Invalid opcode for to_apply(): "
2022                  << HloOpcodeString(opcode());
2023   }
2024 }
2025 
while_condition() const2026 HloComputation* HloInstruction::while_condition() const {
2027   CHECK_EQ(HloOpcode::kWhile, opcode_);
2028   return called_computations_[kConditionComputationIndex];
2029 }
2030 
while_body() const2031 HloComputation* HloInstruction::while_body() const {
2032   CHECK_EQ(HloOpcode::kWhile, opcode_);
2033   return called_computations_[kBodyComputationIndex];
2034 }
2035 
set_while_condition(HloComputation * computation)2036 void HloInstruction::set_while_condition(HloComputation* computation) {
2037   // Don't allow changing the computation for fused instructions so we don't
2038   // have to recompute called_instructions for the entire fusion instruction.
2039   CHECK(!IsFused());
2040   CHECK_EQ(HloOpcode::kWhile, opcode_);
2041   called_computations_[kConditionComputationIndex] = computation;
2042 }
2043 
set_while_body(HloComputation * computation)2044 void HloInstruction::set_while_body(HloComputation* computation) {
2045   // Don't allow changing the computation for fused instructions so we don't
2046   // have to recompute called_instructions for the entire fusion instruction.
2047   CHECK(!IsFused());
2048   CHECK_EQ(HloOpcode::kWhile, opcode_);
2049   called_computations_[kBodyComputationIndex] = computation;
2050 }
2051 
while_init() const2052 HloInstruction* HloInstruction::while_init() const {
2053   CHECK_EQ(HloOpcode::kWhile, opcode_);
2054   return operands_[0];
2055 }
2056 
true_computation() const2057 HloComputation* HloInstruction::true_computation() const {
2058   CHECK_EQ(HloOpcode::kConditional, opcode_);
2059   CHECK_EQ(PRED, operand(0)->shape().element_type());
2060   return called_computations_[kTrueComputationIndex];
2061 }
2062 
false_computation() const2063 HloComputation* HloInstruction::false_computation() const {
2064   CHECK_EQ(HloOpcode::kConditional, opcode_);
2065   CHECK_EQ(PRED, operand(0)->shape().element_type());
2066   return called_computations_[kFalseComputationIndex];
2067 }
2068 
branch_computations() const2069 const std::vector<HloComputation*>& HloInstruction::branch_computations()
2070     const {
2071   CHECK(HloOpcode::kConditional == opcode_);
2072   return called_computations_;
2073 }
2074 
branch_count() const2075 int HloInstruction::branch_count() const {
2076   CHECK(HloOpcode::kConditional == opcode_);
2077   return called_computations_.size();
2078 }
2079 
branch_computation(int b) const2080 HloComputation* HloInstruction::branch_computation(int b) const {
2081   CHECK(HloOpcode::kConditional == opcode_);
2082   CHECK_GE(b, 0);
2083   CHECK_LT(b, called_computations_.size());
2084   return called_computations_[b];
2085 }
2086 
set_branch_computation(int b,HloComputation * computation)2087 void HloInstruction::set_branch_computation(int b,
2088                                             HloComputation* computation) {
2089   // Don't allow changing the computation for fused instructions so we don't
2090   // have to recompute called_instructions for the entire fusion instruction.
2091   CHECK(!IsFused());
2092   CHECK_EQ(HloOpcode::kConditional, opcode_);
2093   called_computations_[b] = computation;
2094 }
2095 
SignatureString() const2096 string HloInstruction::SignatureString() const {
2097   string operands =
2098       StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) {
2099         StrAppend(out, ShapeUtil::HumanString(operand->shape()));
2100       });
2101   return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
2102 }
2103 
2104 namespace {
2105 
PrintName(const string & name,const HloPrintOptions & options)2106 string PrintName(const string& name, const HloPrintOptions& options) {
2107   return StrCat(options.print_percent() ? "%" : "", name);
2108 }
2109 
2110 }  // namespace
2111 
ToString(const HloPrintOptions & options) const2112 string HloInstruction::ToString(const HloPrintOptions& options) const {
2113   CanonicalNameMap new_map;
2114   return ToStringWithCanonicalNameMap(options, &new_map);
2115 }
2116 
IsElementwiseImpl(const absl::optional<int64> & operand_idx) const2117 bool HloInstruction::IsElementwiseImpl(
2118     const absl::optional<int64>& operand_idx) const {
2119   switch (opcode_) {
2120     // Unary elementwise operations.
2121     case HloOpcode::kAbs:
2122     case HloOpcode::kRoundNearestAfz:
2123     case HloOpcode::kCeil:
2124     case HloOpcode::kClz:
2125     case HloOpcode::kConvert:
2126     case HloOpcode::kBitcastConvert:
2127     case HloOpcode::kCopy:
2128     case HloOpcode::kCos:
2129     case HloOpcode::kExp:
2130     case HloOpcode::kExpm1:
2131     case HloOpcode::kFloor:
2132     case HloOpcode::kImag:
2133     case HloOpcode::kIsFinite:
2134     case HloOpcode::kLog:
2135     case HloOpcode::kLog1p:
2136     case HloOpcode::kNot:
2137     case HloOpcode::kNegate:
2138     case HloOpcode::kReal:
2139     case HloOpcode::kReducePrecision:
2140     case HloOpcode::kRsqrt:
2141     case HloOpcode::kSign:
2142     case HloOpcode::kSin:
2143     case HloOpcode::kSqrt:
2144     case HloOpcode::kTanh:
2145       CHECK_EQ(1, operand_count());
2146       return true;
2147 
2148     // Binary elementwise operations, the same as in IsElementwiseBinary().
2149     case HloOpcode::kAdd:
2150     case HloOpcode::kAtan2:
2151     case HloOpcode::kCompare:
2152     case HloOpcode::kComplex:
2153     case HloOpcode::kDivide:
2154     case HloOpcode::kMaximum:
2155     case HloOpcode::kMinimum:
2156     case HloOpcode::kMultiply:
2157     case HloOpcode::kPower:
2158     case HloOpcode::kRemainder:
2159     case HloOpcode::kSubtract:
2160     case HloOpcode::kAnd:
2161     case HloOpcode::kOr:
2162     case HloOpcode::kXor:
2163     case HloOpcode::kShiftLeft:
2164     case HloOpcode::kShiftRightArithmetic:
2165     case HloOpcode::kShiftRightLogical:
2166       CHECK_EQ(2, operand_count());
2167       return true;
2168 
2169     // Ternary elementwise operations.
2170     case HloOpcode::kSelect:
2171     case HloOpcode::kClamp:
2172       return true;
2173 
2174     case HloOpcode::kDynamicUpdateSlice:
2175       return operand_idx.has_value() && operand_idx.value() == 0;
2176 
2177     default:
2178       return false;
2179   }
2180 }
2181 
IsCrossModuleAllReduce() const2182 bool HloInstruction::IsCrossModuleAllReduce() const {
2183   return opcode() == HloOpcode::kAllReduce && all_reduce_id();
2184 }
2185 
IsCrossReplicaAllReduce() const2186 bool HloInstruction::IsCrossReplicaAllReduce() const {
2187   return opcode() == HloOpcode::kAllReduce && !all_reduce_id();
2188 }
2189 
ToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const2190 string HloInstruction::ToStringWithCanonicalNameMap(
2191     const HloPrintOptions& options,
2192     CanonicalNameMap* canonical_name_map) const {
2193   string result = "";
2194 
2195   // Logic to print the instruction name (e.g. "%foo = ").
2196   if (options.canonicalize_instruction_names()) {
2197     if (options.is_in_nested_computation()) {
2198       // If we are canonicalizing instruction names and this is a top-level
2199       // HloInstruction::ToString() call, don't print an instruction name.
2200       StrAppend(&result,
2201                 PrintName(canonical_name_map->LookupOrInsert(name()), options),
2202                 " = ");
2203     }
2204   } else {
2205     StrAppend(&result, PrintName(name(), options), " = ");
2206   }
2207 
2208   // Print opcode, operand(s) and shape.
2209   StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ",
2210             HloOpcodeString(opcode()), "(",
2211             OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
2212             ")");
2213 
2214   // Print additional attributes. If an instruction contains a subcomputation,
2215   // the subcomputation is also printed here.
2216   for (const string& extra : ExtraAttributesToString(options)) {
2217     StrAppend(&result, ", ", extra);
2218   }
2219 
2220   if (options.print_metadata() &&
2221       (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
2222        !metadata_.source_file().empty())) {
2223     StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
2224   }
2225   if (options.print_backend_config() && !backend_config_.empty()) {
2226     StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\"");
2227   }
2228   return result;
2229 }
2230 
OperandsToString(const HloPrintOptions & options) const2231 string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
2232   CanonicalNameMap new_map;
2233   return OperandsToStringWithCanonicalNameMap(options, &new_map);
2234 }
2235 
OperandsToStringWithCanonicalNameMap(const HloPrintOptions & options,CanonicalNameMap * canonical_name_map) const2236 string HloInstruction::OperandsToStringWithCanonicalNameMap(
2237     const HloPrintOptions& options,
2238     CanonicalNameMap* canonical_name_map) const {
2239   string operands;
2240   absl::Span<HloInstruction* const> slice(operands_);
2241   const int64 kMaxOperandsToShowIfCompact = 4;
2242   if (options.compact_operands() &&
2243       slice.size() > kMaxOperandsToShowIfCompact) {
2244     slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
2245   }
2246   operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) {
2247     // If operand is already been deleted, put `null` to the string output.
2248     if (operand == nullptr) {
2249       StrAppend(out, "null ");
2250       return;
2251     }
2252     std::vector<string> str;
2253     if (options.print_operand_shape()) {
2254       str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
2255     }
2256 
2257     // In a top-level HloInstruction::ToString() call, the operand name is not
2258     // part of the canonical string.
2259     if (options.canonicalize_instruction_names() &&
2260         options.is_in_nested_computation()) {
2261       str.push_back(PrintName(
2262           canonical_name_map->LookupOrInsert(operand->name()), options));
2263     } else if (options.print_operand_names()) {
2264       str.push_back(PrintName(operand->name(), options));
2265     }
2266     StrAppend(out, StrJoin(str, " "));
2267   });
2268   const int64 remaining = operands_.size() - slice.size();
2269   if (slice.size() != operands_.size()) {
2270     StrAppend(&operands, ", ...(+", remaining, ")");
2271   }
2272   return operands;
2273 }
2274 
ExtraAttributesToString(const HloPrintOptions & options) const2275 std::vector<string> HloInstruction::ExtraAttributesToString(
2276     const HloPrintOptions& options) const {
2277   std::vector<string> extra = ExtraAttributesToStringImpl(options);
2278 
2279   if (options.print_subcomputation_mode() ==
2280       HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
2281     if (opcode() == HloOpcode::kWhile) {
2282       extra.push_back(
2283           StrCat("condition=", PrintName(while_condition()->name(), options)));
2284       extra.push_back(
2285           StrCat("body=", PrintName(while_body()->name(), options)));
2286     } else if (opcode() == HloOpcode::kSelectAndScatter) {
2287       extra.push_back(StrCat("select=", PrintName(select()->name(), options)));
2288       extra.push_back(
2289           StrCat("scatter=", PrintName(scatter()->name(), options)));
2290     } else if (opcode() == HloOpcode::kConditional) {
2291       if (operand(0)->shape().element_type() == PRED) {
2292         extra.push_back(StrCat("true_computation=",
2293                                PrintName(true_computation()->name(), options)));
2294         extra.push_back(
2295             StrCat("false_computation=",
2296                    PrintName(false_computation()->name(), options)));
2297       } else {
2298         extra.push_back(StrCat(
2299             "branch_computations={",
2300             StrJoin(branch_computations(), ", ",
2301                     [&](string* out, const HloComputation* computation) {
2302                       StrAppend(out, PrintName(computation->name(), options));
2303                     }),
2304             "}"));
2305       }
2306     } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
2307                opcode() == HloOpcode::kReduceWindow ||
2308                opcode() == HloOpcode::kReduce ||
2309                opcode() == HloOpcode::kAllReduce ||
2310                opcode() == HloOpcode::kScatter ||
2311                opcode() == HloOpcode::kSort) {
2312       extra.push_back(
2313           StrCat("to_apply=", PrintName(to_apply()->name(), options)));
2314     } else if (!called_computations().empty()) {
2315       extra.push_back(StrCat(
2316           "calls=",
2317           StrJoin(called_computations(), ", ",
2318                   [&](string* out, const HloComputation* computation) {
2319                     StrAppend(out, PrintName(computation->name(), options));
2320                   })));
2321     }
2322   } else if (options.print_subcomputation_mode() ==
2323              HloPrintOptions::PrintSubcomputationMode::kFullBodies) {
2324     HloPrintOptions new_options = options;
2325     new_options.set_is_in_nested_computation(true);
2326     switch (opcode()) {
2327       case HloOpcode::kWhile:
2328         extra.push_back(
2329             StrCat("condition=\n", while_condition()->ToString(new_options)));
2330         extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
2331         break;
2332       case HloOpcode::kSelectAndScatter:
2333         extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
2334         extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
2335         break;
2336       case HloOpcode::kConditional:
2337         if (operand(0)->shape().element_type() == PRED) {
2338           extra.push_back(StrCat("true_computation=\n",
2339                                  true_computation()->ToString(new_options)));
2340           extra.push_back(StrCat("false_computation=\n",
2341                                  false_computation()->ToString(new_options)));
2342         } else {
2343           extra.push_back(StrCat(
2344               "branch_computations={\n",
2345               StrJoin(branch_computations(), ",\n",
2346                       [&](string* out, const HloComputation* computation) {
2347                         StrAppend(out, computation->ToString(new_options));
2348                       }),
2349               "\n}"));
2350         }
2351         break;
2352       case HloOpcode::kCall:
2353       case HloOpcode::kMap:
2354       case HloOpcode::kReduceWindow:
2355       case HloOpcode::kReduce:
2356       case HloOpcode::kAllReduce:
2357       case HloOpcode::kScatter:
2358       case HloOpcode::kSort:
2359         extra.push_back(
2360             StrCat("to_apply=\n", to_apply()->ToString(new_options)));
2361         break;
2362       default:
2363         if (!called_computations().empty()) {
2364           extra.push_back(StrCat(
2365               "calls=\n",
2366               StrJoin(called_computations(), ", ",
2367                       [&](string* out, const HloComputation* computation) {
2368                         StrAppend(out, computation->ToString(new_options));
2369                       })));
2370         }
2371         break;
2372     }
2373   }
2374 
2375   if (has_sharding()) {
2376     extra.push_back(StrCat("sharding=", sharding().ToString()));
2377   }
2378   if (options.print_control_dependencies() && !control_predecessors_.empty()) {
2379     extra.push_back(StrCat("control-predecessors={",
2380                            StrJoin(control_predecessors_, ", ",
2381                                    [&](string* out, HloInstruction* pre) {
2382                                      StrAppend(out,
2383                                                PrintName(pre->name(), options));
2384                                    }),
2385                            "}"));
2386   }
2387 
2388   return extra;
2389 }
2390 
ToShortString() const2391 string HloInstruction::ToShortString() const {
2392   return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
2393                 StrJoin(operands_, ", ",
2394                         [](string* out, HloInstruction* operand) {
2395                           StrAppend(out, "%", operand->name());
2396                         }),
2397                 ")");
2398 }
2399 
ToProto() const2400 HloInstructionProto HloInstruction::ToProto() const {
2401   HloInstructionProto proto;
2402   CHECK(unique_id_ != -1)
2403       << "This instruction does not have a valid id. Please make sure the "
2404          "instruction is inside a module before dumping it.";
2405   proto.set_id(unique_id_);
2406   proto.set_name(name_);
2407   proto.set_opcode(HloOpcodeString(opcode_));
2408   *proto.mutable_shape() = shape_.ToProto();
2409   for (const HloInstruction* operand : operands_) {
2410     proto.add_operand_ids(operand->unique_id());
2411   }
2412   for (const HloInstruction* control : control_predecessors_) {
2413     proto.add_control_predecessor_ids(control->unique_id());
2414   }
2415 
2416   *proto.mutable_metadata() = metadata_;
2417   proto.set_backend_config(backend_config_);
2418   if (opcode() != HloOpcode::kFusion) {
2419     for (const HloComputation* computation : called_computations_) {
2420       proto.add_called_computation_ids(computation->unique_id());
2421     }
2422   }
2423 
2424   if (has_sharding()) {
2425     *proto.mutable_sharding() = sharding().ToProto();
2426   }
2427 
2428   return proto;
2429 }
2430 
ToCategory() const2431 string HloInstruction::ToCategory() const {
2432   if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
2433       opcode() == HloOpcode::kReshape) {
2434     return "data formatting";
2435   }
2436 
2437   if (IsElementwise()) {
2438     return "non-fusion elementwise";
2439   }
2440 
2441   return HloOpcodeString(opcode());
2442 }
2443 
tracing() const2444 HloInstruction* HloInstruction::tracing() const { return trace_instruction_; }
2445 
set_tracing(HloInstruction * trace_instruction)2446 void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
2447   trace_instruction_ = trace_instruction;
2448 }
2449 
IsFused() const2450 bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
2451 
IsFusible() const2452 bool HloInstruction::IsFusible() const {
2453   // Instructions which are traced should not be fused.
2454   if (tracing()) {
2455     return false;
2456   }
2457   // Some kinds of instructions don't make sense to fuse.
2458   switch (opcode_) {
2459     case HloOpcode::kDomain:
2460     case HloOpcode::kParameter:
2461       return false;
2462     // Side effecting instrutions cannot be fused.
2463     default:
2464       return !HasSideEffect();
2465   }
2466 }
2467 
// Constructs an instruction with the given opcode and result shape.
//
// unique_id_ starts at -1 (ToProto() treats -1 as "no valid id yet") and the
// initial name is the opcode's string form.
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
    : unique_id_(-1),
      opcode_(opcode),
      shape_(shape),
      name_(HloOpcodeString(opcode)) {
  // Sanity-check the stored shape; the layout is allowed to be absent.
  // NOTE(review): TF_DCHECK_OK is presumably a debug-only check - confirm.
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}
2475 
// Dispatches this instruction to the Handle* method of `visitor` that
// corresponds to its opcode and returns that handler's Status.
//
// kTrace has no handler: it breaks out of the switch and, like any opcode
// that is not listed, produces an InternalError naming the opcode.
template <typename HloInstructionPtr>
Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
  switch (opcode_) {
    case HloOpcode::kAbs:
      return visitor->HandleAbs(this);
    case HloOpcode::kAtan2:
      return visitor->HandleAtan2(this);
    // Note: the handler for kRoundNearestAfz is named HandleRound.
    case HloOpcode::kRoundNearestAfz:
      return visitor->HandleRound(this);
    case HloOpcode::kBatchNormTraining:
      return visitor->HandleBatchNormTraining(this);
    case HloOpcode::kBatchNormInference:
      return visitor->HandleBatchNormInference(this);
    case HloOpcode::kBatchNormGrad:
      return visitor->HandleBatchNormGrad(this);
    case HloOpcode::kSign:
      return visitor->HandleSign(this);
    case HloOpcode::kConstant:
      return visitor->HandleConstant(this);
    case HloOpcode::kGetTupleElement:
      return visitor->HandleGetTupleElement(this);
    case HloOpcode::kParameter:
      return visitor->HandleParameter(this);
    case HloOpcode::kCompare:
      return visitor->HandleCompare(this);
    case HloOpcode::kComplex:
      return visitor->HandleComplex(this);
    case HloOpcode::kAdd:
      return visitor->HandleAdd(this);
    case HloOpcode::kDivide:
      return visitor->HandleDivide(this);
    case HloOpcode::kSubtract:
      return visitor->HandleSubtract(this);
    case HloOpcode::kMaximum:
      return visitor->HandleMaximum(this);
    case HloOpcode::kMinimum:
      return visitor->HandleMinimum(this);
    case HloOpcode::kAnd:
      return visitor->HandleAnd(this);
    case HloOpcode::kOr:
      return visitor->HandleOr(this);
    case HloOpcode::kXor:
      return visitor->HandleXor(this);
    case HloOpcode::kShiftLeft:
      return visitor->HandleShiftLeft(this);
    case HloOpcode::kShiftRightArithmetic:
      return visitor->HandleShiftRightArithmetic(this);
    case HloOpcode::kShiftRightLogical:
      return visitor->HandleShiftRightLogical(this);
    case HloOpcode::kConcatenate:
      return visitor->HandleConcatenate(this);
    case HloOpcode::kConvert:
      return visitor->HandleConvert(this);
    case HloOpcode::kBitcastConvert:
      return visitor->HandleBitcastConvert(this);
    case HloOpcode::kCopy:
      return visitor->HandleCopy(this);
    case HloOpcode::kMultiply:
      return visitor->HandleMultiply(this);
    case HloOpcode::kDot:
      return visitor->HandleDot(this);
    case HloOpcode::kPower:
      return visitor->HandlePower(this);
    case HloOpcode::kRemainder:
      return visitor->HandleRemainder(this);
    case HloOpcode::kSelect:
      return visitor->HandleSelect(this);
    case HloOpcode::kTupleSelect:
      return visitor->HandleTupleSelect(this);
    case HloOpcode::kConvolution:
      return visitor->HandleConvolution(this);
    case HloOpcode::kFft:
      return visitor->HandleFft(this);
    case HloOpcode::kAllReduce:
      return visitor->HandleAllReduce(this);
    case HloOpcode::kAllToAll:
      return visitor->HandleAllToAll(this);
    case HloOpcode::kCollectivePermute:
      return visitor->HandleCollectivePermute(this);
    case HloOpcode::kReplicaId:
      return visitor->HandleReplicaId(this);
    case HloOpcode::kTuple:
      return visitor->HandleTuple(this);
    case HloOpcode::kMap:
      return visitor->HandleMap(this);
    case HloOpcode::kClamp:
      return visitor->HandleClamp(this);
    case HloOpcode::kReduce:
      return visitor->HandleReduce(this);
    case HloOpcode::kReduceWindow:
      return visitor->HandleReduceWindow(this);
    case HloOpcode::kSelectAndScatter:
      return visitor->HandleSelectAndScatter(this);
    case HloOpcode::kNegate:
      return visitor->HandleNegate(this);
    case HloOpcode::kExp:
      return visitor->HandleExp(this);
    case HloOpcode::kExpm1:
      return visitor->HandleExpm1(this);
    case HloOpcode::kFloor:
      return visitor->HandleFloor(this);
    case HloOpcode::kCeil:
      return visitor->HandleCeil(this);
    case HloOpcode::kClz:
      return visitor->HandleClz(this);
    case HloOpcode::kLog:
      return visitor->HandleLog(this);
    case HloOpcode::kLog1p:
      return visitor->HandleLog1p(this);
    case HloOpcode::kTanh:
      return visitor->HandleTanh(this);
    case HloOpcode::kCos:
      return visitor->HandleCos(this);
    case HloOpcode::kSin:
      return visitor->HandleSin(this);
    case HloOpcode::kSqrt:
      return visitor->HandleSqrt(this);
    case HloOpcode::kRsqrt:
      return visitor->HandleRsqrt(this);
    case HloOpcode::kReal:
      return visitor->HandleReal(this);
    case HloOpcode::kImag:
      return visitor->HandleImag(this);
    case HloOpcode::kIsFinite:
      return visitor->HandleIsFinite(this);
    case HloOpcode::kNot:
      return visitor->HandleNot(this);
    case HloOpcode::kBitcast:
      return visitor->HandleBitcast(this);
    case HloOpcode::kBroadcast:
      return visitor->HandleBroadcast(this);
    case HloOpcode::kPad:
      return visitor->HandlePad(this);
    case HloOpcode::kReshape:
      return visitor->HandleReshape(this);
    case HloOpcode::kTranspose:
      return visitor->HandleTranspose(this);
    case HloOpcode::kReverse:
      return visitor->HandleReverse(this);
    case HloOpcode::kReducePrecision:
      return visitor->HandleReducePrecision(this);
    case HloOpcode::kSlice:
      return visitor->HandleSlice(this);
    case HloOpcode::kDynamicSlice:
      return visitor->HandleDynamicSlice(this);
    case HloOpcode::kDynamicUpdateSlice:
      return visitor->HandleDynamicUpdateSlice(this);
    case HloOpcode::kSort:
      return visitor->HandleSort(this);
    case HloOpcode::kInfeed:
      return visitor->HandleInfeed(this);
    case HloOpcode::kOutfeed:
      return visitor->HandleOutfeed(this);
    case HloOpcode::kRng:
      return visitor->HandleRng(this);
    case HloOpcode::kWhile:
      return visitor->HandleWhile(this);
    case HloOpcode::kFusion:
      return visitor->HandleFusion(this);
    case HloOpcode::kCall:
      return visitor->HandleCall(this);
    case HloOpcode::kConditional:
      return visitor->HandleConditional(this);
    case HloOpcode::kCustomCall:
      return visitor->HandleCustomCall(this);
    case HloOpcode::kRecv:
      return visitor->HandleRecv(this);
    case HloOpcode::kRecvDone:
      return visitor->HandleRecvDone(this);
    case HloOpcode::kSend:
      return visitor->HandleSend(this);
    case HloOpcode::kSendDone:
      return visitor->HandleSendDone(this);
    case HloOpcode::kGather:
      return visitor->HandleGather(this);
    case HloOpcode::kScatter:
      return visitor->HandleScatter(this);
    case HloOpcode::kDomain:
      return visitor->HandleDomain(this);
    case HloOpcode::kAfterAll:
      return visitor->HandleAfterAll(this);
    case HloOpcode::kAddDependency:
      return visitor->HandleAddDependency(this);
    case HloOpcode::kIota:
      return visitor->HandleIota(this);
    case HloOpcode::kGetDimensionSize:
      return visitor->HandleGetDimensionSize(this);
    case HloOpcode::kTriangularSolve:
      return visitor->HandleTriangularSolve(this);
    case HloOpcode::kCholesky:
      return visitor->HandleCholesky(this);

    // These opcodes are not handled here.
    case HloOpcode::kTrace:
      break;
  }
  // Reached for kTrace and for any opcode value without a case above.
  return InternalError(
      "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
      "please file a bug for XLA.",
      HloOpcodeString(opcode_));
}

// Explicit instantiations for the two visitor types used with Visit(), so the
// template definition can stay in this translation unit.
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
2681 
// Work stack for the DFS traversal helpers below. Each entry pairs an
// instruction's unique id with the instruction pointer; up to 16 entries are
// stored inline in the InlinedVector.
using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
2683 
2684 // Push "child" onto the dfs_stack if not already visited.  Returns false if a
2685 // cycle was detected, and true otherwise.
2686 template <typename Visitor>
PushDFSChild(Visitor * visitor,DFSStack * dfs_stack,HloInstruction * child)2687 inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
2688                          HloInstruction* child) {
2689   CHECK(child != nullptr);
2690   const int id = child->unique_id();
2691   CHECK_GE(id, 0) << "instruction may not have a parent computation";
2692   switch (visitor->GetVisitState(id)) {
2693     case Visitor::kVisiting:
2694       return false;
2695 
2696     case Visitor::kVisited:
2697       // Nothing to do
2698       return true;
2699 
2700     case Visitor::kNotVisited:
2701       dfs_stack->push_back(std::make_pair(id, child));
2702       return true;
2703   }
2704 }
2705 
// Comparator over (int, instruction) pairs - the same shape as DFSStack
// entries. PostOrderDFS passes it to std::sort to order the stack entries
// newly pushed for one instruction.
using InternalCompareFunction =
    std::function<bool(std::pair<int, const HloInstruction*>,
                       std::pair<int, const HloInstruction*>)>;
// Iterative post-order depth-first traversal starting at `root`.
//
// Every instruction reachable through operands (and, unless
// `ignore_control_predecessors` is true, through control predecessors) is
// handed to `visitor` via Preprocess / Visit / Postprocess once the children
// pushed for it have been processed. If `operand_order` is non-null it is used
// to sort the children pushed for each node.
//
// Returns a FailedPrecondition error when a cycle is detected, the first
// non-OK status produced by the visitor, or OK otherwise.
template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
                           const InternalCompareFunction* operand_order,
                           bool ignore_control_predecessors) {
  visitor->ReserveVisitStates(root->GetModule()->instruction_count());

  // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
  //
  // We need to keep track of both the id and the instruction because
  // instructions can get deleted while they are on the stack, so we
  // can't always use the (potentially dead) instruction object to grab
  // its id.
  DFSStack dfs_stack;
  dfs_stack.emplace_back(root->unique_id(), root);

  do {
    DCHECK(!dfs_stack.empty());

    int current_id = dfs_stack.back().first;
    HloInstruction* current_node = dfs_stack.back().second;
    CHECK_GE(current_id, 0) << current_id << ": " << current_node
                            << ": instruction may not have parent computation";
    typename Visitor::VisitState visit_state =
        visitor->GetVisitState(current_id);
    // Already handled (e.g. reached earlier through another user): drop it.
    if (visit_state == Visitor::kVisited) {
      dfs_stack.pop_back();
      VLOG(3) << "Not visiting HLO %" << current_node->name()
              << " as it was already visited.";
      continue;
    }

    // Second time this entry is on top of the stack: everything pushed above
    // it has been processed, so the node itself can be visited now.
    if (visit_state == Visitor::kVisiting) {
      dfs_stack.pop_back();

      TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
      VLOG(2) << "Visiting HLO %" << current_node->name();
      TF_RETURN_IF_ERROR(current_node->Visit(visitor));
      visitor->SetVisitState(current_id, Visitor::kVisited);
      TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
      continue;
    }

    // First encounter: mark the node as in progress, leave it on the stack and
    // push its children on top of it.
    visitor->SetVisitState(current_id, Visitor::kVisiting);

    // Remember where this node's children start so that only they are sorted
    // and reversed below.
    const size_t old_dfs_stack_size = dfs_stack.size();
    for (HloInstruction* child : current_node->operands()) {
      if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
        return FailedPrecondition(
            "A cycle is detected while visiting instruction %s",
            current_node->ToString());
      }
    }

    if (!ignore_control_predecessors) {
      for (HloInstruction* child : current_node->control_predecessors()) {
        if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
          return FailedPrecondition(
              "A cycle is detected while visiting instruction %s",
              current_node->ToString());
        }
      }
    }

    if (operand_order != nullptr) {
      std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
                *operand_order);
    }

    // This makes the traversal order the same as what you'd expect
    // out of a recursive algorithm.
    std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
  } while (!dfs_stack.empty());

  return Status::OK();
}
2784 
2785 template <typename HloInstructionPtr>
Accept(DfsHloVisitorBase<HloInstructionPtr> * visitor,bool call_finish_visit,bool ignore_control_predecessors)2786 Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
2787                               bool call_finish_visit,
2788                               bool ignore_control_predecessors) {
2789   VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
2790   TF_RETURN_IF_ERROR(
2791       PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
2792   if (call_finish_visit) {
2793     TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
2794   }
2795   return Status::OK();
2796 }
2797 
// Explicit instantiations of the templated Accept above, one for the mutable
// visitor type and one for the const visitor type.
template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
2801 
AcceptWithOperandOrder(DfsHloVisitor * visitor,const CompareFunction & operand_order,bool call_finish_visit)2802 Status HloInstruction::AcceptWithOperandOrder(
2803     DfsHloVisitor* visitor, const CompareFunction& operand_order,
2804     bool call_finish_visit) {
2805   VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
2806   InternalCompareFunction func = [&operand_order](
2807                                      std::pair<int, const HloInstruction*> a,
2808                                      std::pair<int, const HloInstruction*> b) {
2809     // Call the client's comparison function on the actual HloInstruction*
2810     // objects (ignoring the internal ids we also have in our stack entries)
2811     return operand_order(a.second, b.second);
2812   };
2813   TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
2814                                   /*ignore_control_predecessors=*/false));
2815   if (call_finish_visit) {
2816     VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
2817     TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
2818     VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
2819   }
2820   VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
2821   return Status::OK();
2822 }
2823 
Accept(const std::function<Status (HloInstruction *)> & visitor_func)2824 Status HloInstruction::Accept(
2825     const std::function<Status(HloInstruction*)>& visitor_func) {
2826   FunctionVisitor visitor(visitor_func);
2827   return this->Accept(&visitor);
2828 }
2829 
Accept(const std::function<Status (const HloInstruction *)> & visitor_func) const2830 Status HloInstruction::Accept(
2831     const std::function<Status(const HloInstruction*)>& visitor_func) const {
2832   ConstFunctionVisitor visitor(visitor_func);
2833   return this->Accept(&visitor);
2834 }
2835 
shape() const2836 const Shape& HloInstruction::shape() const { return shape_; }
2837 
OperandIndices(const HloInstruction * operand) const2838 std::vector<int64> HloInstruction::OperandIndices(
2839     const HloInstruction* operand) const {
2840   std::vector<int64> result;
2841   for (int64 i = 0; i < operand_count(); ++i) {
2842     if (this->operand(i) == operand) {
2843       result.push_back(i);
2844     }
2845   }
2846   return result;
2847 }
2848 
IsElementwiseBinary() const2849 bool HloInstruction::IsElementwiseBinary() const {
2850   return IsElementwise() && operand_count() == 2;
2851 }
2852 
IsElementwise() const2853 bool HloInstruction::IsElementwise() const {
2854   return IsElementwiseImpl(absl::nullopt);
2855 }
2856 
IsElementwiseOnOperand(int64 operand_idx) const2857 bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
2858   return IsElementwiseImpl(operand_idx);
2859 }
2860 
2861 // A helper class for memoized, recursive computation of HloOpcode::kFusion
2862 // in HloInstruction::OperandElementUse below.
2863 class HloInstruction::FusionReusesParamElements {
2864  public:
2865   using UseKind = HloInstruction::UseKind;
2866 
2867   // We could rather iterate backwards through fused_instructions_ here, as it
2868   // is in reverse postorder, and compute whether each fused instruction reuses
2869   // the value of this parameter, which would save stack space but not allow us
2870   // to finish early if we find a reuse.
Compute(int64 i,const HloInstruction & hlo)2871   static UseKind Compute(int64 i, const HloInstruction& hlo) {
2872     absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
2873     return ComputeInternal(i, hlo, &memoization_cache);
2874   }
2875 
2876  private:
ComputeInternal(int64 i,const HloInstruction & hlo,absl::flat_hash_map<const HloInstruction *,UseKind> * cache)2877   static UseKind ComputeInternal(
2878       int64 i, const HloInstruction& hlo,
2879       absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
2880     if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
2881       if (hlo_param->parameter_number() == i) {
2882         return UseKind::kUse;
2883       }
2884     }
2885 
2886     auto p = cache->emplace(&hlo, UseKind{});
2887     auto value_it = p.first;
2888     const bool key_is_new = p.second;
2889 
2890     if (key_is_new) {
2891       for (int64 j = 0; j < hlo.operands_.size(); ++j) {
2892         UseKind old_val = value_it->second;
2893 
2894         // The next operation invalidates iterators.
2895         UseKind new_val =
2896             Plus(old_val, std::min(hlo.OperandElementUse(j),
2897                                    ComputeInternal(i, *hlo.operand(j), cache)));
2898 
2899         // Re-acquire the iterator. We could work harder to do this only if
2900         // absolutely necessary, but this code is not hot enough to warrant
2901         // that.
2902         value_it = cache->find(&hlo);
2903         value_it->second = new_val;
2904       }
2905     }
2906     return value_it->second;
2907   }
2908 
2909   // Fold operation for UseKinds.
Plus(UseKind a,UseKind b)2910   static UseKind Plus(UseKind a, UseKind b) {
2911     if (a == UseKind::kNoUse) {
2912       return b;
2913     } else if (b == UseKind::kNoUse) {
2914       return a;
2915     } else if (a == UseKind::kReuse || b == UseKind::kReuse) {
2916       return UseKind::kReuse;
2917     } else if (a == UseKind::kUsePermutingElements ||
2918                b == UseKind::kUsePermutingElements) {
2919       return UseKind::kReuse;
2920     } else {
2921       CHECK(a == UseKind::kUse && b == UseKind::kUse);
2922       return UseKind::kUse;
2923     }
2924   }
2925 };
2926 
OperandElementUse(int64 i) const2927 HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
2928   switch (opcode_) {
2929     case HloOpcode::kBitcast:
2930     case HloOpcode::kConcatenate:
2931     case HloOpcode::kReshape:
2932     case HloOpcode::kReverse:
2933     case HloOpcode::kSlice:
2934     case HloOpcode::kTranspose:
2935       return UseKind::kUsePermutingElements;
2936     case HloOpcode::kPad:
2937       // Pad reuses the padding value but not the padded array elements.
2938       return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
2939     case HloOpcode::kReduce:
2940       // Reduce reuses the init values but not the operand array elements.
2941       return i >= Cast<HloReduceInstruction>(this)->input_count()
2942                  ? UseKind::kReuse
2943                  : UseKind::kUsePermutingElements;
2944     case HloOpcode::kFusion:
2945       // Uses the memoizing, recursive computation defined above.
2946       return FusionReusesParamElements::Compute(i, *fused_expression_root());
2947     case HloOpcode::kDot:
2948       // Dot operations with inputs [A,B] * [B,1] do not re-use
2949       // elements on their left operand.
2950       // Dot operations with inputs [1,A] * [A,B] do not re-use
2951       // elements on their right operand.
2952       if (shape().dimensions_size() == 2) {
2953         if ((i == 0 && shape().dimensions(1) == 1) ||
2954             (i == 1 && shape().dimensions(0) == 1)) {
2955           return UseKind::kUse;
2956         }
2957       }
2958       return UseKind::kReuse;
2959     case HloOpcode::kDynamicUpdateSlice:
2960       // Dynamic-update-slice reuses only start_indices.
2961       if (i == 0 || i == 1) {
2962         return UseKind::kUse;
2963       }
2964       return UseKind::kReuse;
2965     default:
2966       return IsElementwise() ? UseKind::kUse : UseKind::kReuse;
2967   }
2968 }
2969 
2970 std::tuple<bool, std::vector<int64>, std::vector<int64>>
ReshapeMerelyInsertsOrDeletes1SizedDimensions() const2971 HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
2972   if (HloOpcode::kReshape != opcode_) {
2973     return std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
2974   }
2975   return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
2976                                                       shape_);
2977 }
2978 
ToString(HloInstruction::FusionKind kind)2979 string ToString(HloInstruction::FusionKind kind) {
2980   switch (kind) {
2981     case HloInstruction::FusionKind::kLoop:
2982       return "kLoop";
2983     case HloInstruction::FusionKind::kInput:
2984       return "kInput";
2985     case HloInstruction::FusionKind::kOutput:
2986       return "kOutput";
2987     case HloInstruction::FusionKind::kCustom:
2988       return "kCustom";
2989   }
2990 }
2991 
StringToFusionKind(const string & kind_name)2992 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
2993     const string& kind_name) {
2994   if (kind_name == "kLoop") {
2995     return HloInstruction::FusionKind::kLoop;
2996   }
2997   if (kind_name == "kInput") {
2998     return HloInstruction::FusionKind::kInput;
2999   }
3000   if (kind_name == "kOutput") {
3001     return HloInstruction::FusionKind::kOutput;
3002   }
3003   if (kind_name == "kCustom") {
3004     return HloInstruction::FusionKind::kCustom;
3005   }
3006   return InvalidArgument("Unknown fusion kind: %s", kind_name);
3007 }
3008 
PaddingConfigToString(const PaddingConfig & padding)3009 string PaddingConfigToString(const PaddingConfig& padding) {
3010   bool has_interior_padding =
3011       absl::c_any_of(padding.dimensions(),
3012                      [](const PaddingConfig::PaddingConfigDimension& dim) {
3013                        return dim.interior_padding() != 0;
3014                      });
3015   return StrJoin(
3016       padding.dimensions(), "x",
3017       [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
3018         StrAppend(
3019             out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
3020             has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
3021       });
3022 }
3023 
OpMetadataToString(const OpMetadata & metadata)3024 string OpMetadataToString(const OpMetadata& metadata) {
3025   std::vector<string> result;
3026   if (!metadata.op_type().empty()) {
3027     result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\""));
3028   }
3029   if (!metadata.op_name().empty()) {
3030     result.push_back(StrCat("op_name=\"", CEscape(metadata.op_name()), "\""));
3031   }
3032   if (!metadata.source_file().empty()) {
3033     result.push_back(
3034         StrCat("source_file=\"", CEscape(metadata.source_file()), "\""));
3035   }
3036   if (metadata.source_line() != 0) {
3037     result.push_back(StrCat("source_line=", metadata.source_line()));
3038   }
3039   return StrJoin(result, " ");
3040 }
3041 
RandomDistributionToString(const RandomDistribution & distribution)3042 string RandomDistributionToString(const RandomDistribution& distribution) {
3043   return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
3044 }
3045 
PrecisionToString(const PrecisionConfig::Precision & precision)3046 string PrecisionToString(const PrecisionConfig::Precision& precision) {
3047   return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
3048 }
3049 
ConvolutionDimensionNumbersToString(const ConvolutionDimensionNumbers & dnums)3050 string ConvolutionDimensionNumbersToString(
3051     const ConvolutionDimensionNumbers& dnums) {
3052   // lhs_dims[i] is the symbol of the logical dimension i for the lhs
3053   // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
3054   std::vector<string> lhs_dims(2 + dnums.input_spatial_dimensions().size());
3055   lhs_dims[dnums.input_batch_dimension()] = 'b';
3056   lhs_dims[dnums.input_feature_dimension()] = 'f';
3057   for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
3058     lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
3059   }
3060 
3061   std::vector<string> rhs_dims(2 + dnums.kernel_spatial_dimensions().size());
3062   rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
3063   rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
3064   for (int64 i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
3065     rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
3066   }
3067 
3068   std::vector<string> output_dims(2 + dnums.output_spatial_dimensions().size());
3069   output_dims[dnums.output_batch_dimension()] = 'b';
3070   output_dims[dnums.output_feature_dimension()] = 'f';
3071   for (int64 i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
3072     output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
3073   }
3074 
3075   return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
3076                 StrJoin(output_dims, ""));
3077 }
3078 
StringToRandomDistribution(const string & name)3079 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
3080   static std::unordered_map<string, RandomDistribution>* map = [] {
3081     static auto* map = new std::unordered_map<string, RandomDistribution>;
3082     for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
3083       if (RandomDistribution_IsValid(i)) {
3084         auto value = static_cast<RandomDistribution>(i);
3085         (*map)[RandomDistributionToString(value)] = value;
3086       }
3087     }
3088     return map;
3089   }();
3090   auto found = map->find(absl::AsciiStrToLower(name));
3091   if (found == map->end()) {
3092     return InvalidArgument("Unknown distribution");
3093   }
3094   return found->second;
3095 }
3096 
StringToPrecision(const string & name)3097 StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
3098   static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
3099     static auto* map =
3100         new std::unordered_map<string, PrecisionConfig::Precision>;
3101     for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
3102       if (PrecisionConfig::Precision_IsValid(i)) {
3103         auto value = static_cast<PrecisionConfig::Precision>(i);
3104         (*map)[PrecisionToString(value)] = value;
3105       }
3106     }
3107     return map;
3108   }();
3109   auto found = map->find(absl::AsciiStrToLower(name));
3110   if (found == map->end()) {
3111     return InvalidArgument("Unknown distribution");
3112   }
3113   return found->second;
3114 }
3115 
operator <<(std::ostream & os,HloInstruction::FusionKind kind)3116 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
3117   return os << ToString(kind);
3118 }
3119 
operator ()(const HloInstruction * const & lhs,const HloInstruction * const & rhs) const3120 bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
3121                                   const HloInstruction* const& rhs) const {
3122   if (rhs == nullptr) {
3123     // Nothing compares less than nullptr.
3124     return false;
3125   }
3126   if (lhs == nullptr) {
3127     return true;
3128   }
3129   auto lhs_module = lhs->GetModule();
3130   auto rhs_module = rhs->GetModule();
3131   CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
3132         (lhs_module != nullptr && rhs_module != nullptr));
3133   if (lhs_module != nullptr &&
3134       lhs_module->unique_id() != rhs_module->unique_id()) {
3135     return lhs_module->unique_id() < rhs_module->unique_id();
3136   }
3137   return lhs->unique_id() < rhs->unique_id();
3138 }
3139 
CouldBeBitcast() const3140 bool HloInstruction::CouldBeBitcast() const {
3141   switch (opcode_) {
3142     case HloOpcode::kTranspose:
3143       return true;
3144     case HloOpcode::kReshape:
3145       return std::get<0>(ReshapeMerelyInsertsOrDeletes1SizedDimensions());
3146     default:
3147       return false;
3148   }
3149 }
3150 
GetBackendConfigInternal(tensorflow::protobuf::Message * proto) const3151 Status HloInstruction::GetBackendConfigInternal(
3152     tensorflow::protobuf::Message* proto) const {
3153   proto->Clear();
3154 
3155   // Empty string does not parse as valid JSON, but it's a valid backend config,
3156   // corresponding to the empty proto.
3157   if (backend_config_.empty()) {
3158     return Status::OK();
3159   }
3160   return tensorflow::HumanReadableJsonToProto(backend_config_, proto);
3161 }
3162 
set_backend_config(const tensorflow::protobuf::Message & proto)3163 Status HloInstruction::set_backend_config(
3164     const tensorflow::protobuf::Message& proto) {
3165   TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto));
3166   return Status::OK();
3167 }
3168 
BackendConfigToRawString(const tensorflow::protobuf::Message & proto)3169 /* static */ StatusOr<string> HloInstruction::BackendConfigToRawString(
3170     const tensorflow::protobuf::Message& proto) {
3171   string ret;
3172   TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &ret));
3173   return ret;
3174 }
3175 
precision_config() const3176 const PrecisionConfig& HloInstruction::precision_config() const {
3177   if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
3178     return convolution->precision_config();
3179   }
3180   if (auto* dot = DynCast<HloDotInstruction>(this)) {
3181     return dot->precision_config();
3182   }
3183   LOG(FATAL) << "Unimplemented method.";
3184 }
3185 
mutable_precision_config()3186 PrecisionConfig* HloInstruction::mutable_precision_config() {
3187   if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
3188     return convolution->mutable_precision_config();
3189   }
3190   if (auto* dot = DynCast<HloDotInstruction>(this)) {
3191     return dot->mutable_precision_config();
3192   }
3193   LOG(FATAL) << "Unimplemented method.";
3194 }
3195 
GetModule() const3196 HloModule* HloInstruction::GetModule() const {
3197   if (parent_) {
3198     return parent_->parent();
3199   }
3200   return nullptr;
3201 }
3202 
UniquifyName(NameUniquer * name_uniquer)3203 void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
3204   string parent_str = parent() == nullptr ? "noparent" : parent()->name();
3205   name_ = name_uniquer->GetUniqueName(name_);
3206 }
3207 
set_outer_dimension_partitions(const std::vector<int64> & outer_dimension_partitions)3208 void HloInstruction::set_outer_dimension_partitions(
3209     const std::vector<int64>& outer_dimension_partitions) {
3210   outer_dimension_partitions_ = outer_dimension_partitions;
3211 }
3212 
3213 // TODO(b/80131774): Remove these temporary methods after transition.
feature_index() const3214 int64 HloInstruction::feature_index() const {
3215   return Cast<HloBatchNormInstruction>(this)->feature_index();
3216 }
3217 
epsilon() const3218 float HloInstruction::epsilon() const {
3219   return Cast<HloBatchNormInstruction>(this)->epsilon();
3220 }
3221 
fft_type() const3222 FftType HloInstruction::fft_type() const {
3223   return Cast<HloFftInstruction>(this)->fft_type();
3224 }
3225 
fft_length() const3226 const std::vector<int64>& HloInstruction::fft_length() const {
3227   return Cast<HloFftInstruction>(this)->fft_length();
3228 }
3229 
channel_id() const3230 int64 HloInstruction::channel_id() const {
3231   return Cast<HloSendRecvInstruction>(this)->channel_id();
3232 }
3233 
concatenate_dimension() const3234 int64 HloInstruction::concatenate_dimension() const {
3235   return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
3236 }
3237 
dimension() const3238 int64 HloInstruction::dimension() const {
3239   return Cast<HloGetDimensionSizeInstruction>(this)->dimension();
3240 }
3241 
IsRank2Transpose() const3242 bool HloInstruction::IsRank2Transpose() const {
3243   auto transpose = DynCast<HloTransposeInstruction>(this);
3244   return transpose != nullptr && transpose->IsRank2Transpose();
3245 }
3246 
slice_starts(int64 dimension) const3247 int64 HloInstruction::slice_starts(int64 dimension) const {
3248   return Cast<HloSliceInstruction>(this)->slice_starts(dimension);
3249 }
3250 
slice_starts() const3251 const std::vector<int64>& HloInstruction::slice_starts() const {
3252   return Cast<HloSliceInstruction>(this)->slice_starts();
3253 }
3254 
slice_limits(int64 dimension) const3255 int64 HloInstruction::slice_limits(int64 dimension) const {
3256   return Cast<HloSliceInstruction>(this)->slice_limits(dimension);
3257 }
3258 
slice_limits() const3259 const std::vector<int64>& HloInstruction::slice_limits() const {
3260   return Cast<HloSliceInstruction>(this)->slice_limits();
3261 }
3262 
slice_strides(int64 dimension) const3263 int64 HloInstruction::slice_strides(int64 dimension) const {
3264   return Cast<HloSliceInstruction>(this)->slice_strides(dimension);
3265 }
3266 
slice_strides() const3267 const std::vector<int64>& HloInstruction::slice_strides() const {
3268   return Cast<HloSliceInstruction>(this)->slice_strides();
3269 }
3270 
literal() const3271 const Literal& HloInstruction::literal() const {
3272   return Cast<HloConstantInstruction>(this)->literal();
3273 }
3274 
IsConstant() const3275 bool HloInstruction::IsConstant() const {
3276   return DynCast<HloConstantInstruction>(this) != nullptr;
3277 }
3278 
RelayoutConstant(const Layout & new_layout,const ShapeIndex & shape_index)3279 void HloInstruction::RelayoutConstant(const Layout& new_layout,
3280                                       const ShapeIndex& shape_index) {
3281   Cast<HloConstantInstruction>(this)->RelayoutConstant(new_layout, shape_index);
3282 }
3283 
TracingTag() const3284 string HloInstruction::TracingTag() const {
3285   return Cast<HloTraceInstruction>(this)->TracingTag();
3286 }
3287 
AddFusionOperand(HloInstruction * new_operand)3288 HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
3289   return Cast<HloFusionInstruction>(this)->AddFusionOperand(new_operand);
3290 }
3291 
3292 // Delegates to HloFusionInstruction::MergeFusionInstruction.
MergeFusionInstruction(HloInstruction * instruction_to_merge)3293 void HloInstruction::MergeFusionInstruction(
3294     HloInstruction* instruction_to_merge) {
3295   return Cast<HloFusionInstruction>(this)->MergeFusionInstruction(
3296       Cast<HloFusionInstruction>(instruction_to_merge));
3297 }
3298 
3299 // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
MergeFusionInstructionIntoMultiOutput(HloInstruction * instruction_to_merge)3300 void HloInstruction::MergeFusionInstructionIntoMultiOutput(
3301     HloInstruction* instruction_to_merge) {
3302   return Cast<HloFusionInstruction>(this)
3303       ->MergeFusionInstructionIntoMultiOutput(
3304           Cast<HloFusionInstruction>(instruction_to_merge));
3305 }
3306 
FuseInstruction(HloInstruction * instruction_to_fuse)3307 HloInstruction* HloInstruction::FuseInstruction(
3308     HloInstruction* instruction_to_fuse) {
3309   return Cast<HloFusionInstruction>(this)->FuseInstruction(instruction_to_fuse);
3310 }
3311 
FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)3312 HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput(
3313     HloInstruction* instruction_to_fuse) {
3314   return Cast<HloFusionInstruction>(this)->FuseInstructionIntoMultiOutput(
3315       instruction_to_fuse);
3316 }
3317 
fused_instructions_computation() const3318 HloComputation* HloInstruction::fused_instructions_computation() const {
3319   return Cast<HloFusionInstruction>(this)->fused_instructions_computation();
3320 }
3321 
fused_expression_root() const3322 HloInstruction* HloInstruction::fused_expression_root() const {
3323   return Cast<HloFusionInstruction>(this)->fused_expression_root();
3324 }
3325 
3326 const tensorflow::gtl::iterator_range<UnwrappingIterator<
3327     std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
fused_instructions() const3328 HloInstruction::fused_instructions() const {
3329   return Cast<HloFusionInstruction>(this)->fused_instructions();
3330 }
3331 
3332 const tensorflow::gtl::iterator_range<
3333     UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
fused_instructions()3334 HloInstruction::fused_instructions() {
3335   return Cast<HloFusionInstruction>(this)->fused_instructions();
3336 }
3337 
fused_instruction_count() const3338 int64 HloInstruction::fused_instruction_count() const {
3339   return Cast<HloFusionInstruction>(this)->fused_instruction_count();
3340 }
3341 
fused_parameter(int64 parameter_number) const3342 HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
3343   return Cast<HloFusionInstruction>(this)->fused_parameter(parameter_number);
3344 }
3345 
fused_parameters() const3346 const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
3347   return Cast<HloFusionInstruction>(this)->fused_parameters();
3348 }
3349 
IsMultiOutputFusion() const3350 const bool HloInstruction::IsMultiOutputFusion() const {
3351   const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(this);
3352   return fusion != nullptr && fusion->IsMultiOutputFusion();
3353 }
3354 
fusion_kind() const3355 HloInstruction::FusionKind HloInstruction::fusion_kind() const {
3356   return Cast<HloFusionInstruction>(this)->fusion_kind();
3357 }
3358 
set_fusion_kind(FusionKind kind)3359 void HloInstruction::set_fusion_kind(FusionKind kind) {
3360   return Cast<HloFusionInstruction>(this)->set_fusion_kind(kind);
3361 }
3362 
random_distribution() const3363 RandomDistribution HloInstruction::random_distribution() const {
3364   return Cast<HloRngInstruction>(this)->random_distribution();
3365 }
3366 
parameter_number() const3367 int64 HloInstruction::parameter_number() const {
3368   return Cast<HloParameterInstruction>(this)->parameter_number();
3369 }
3370 
set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)3371 void HloInstruction::set_parameter_replicated_at_leaf_buffers(
3372     absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
3373   return Cast<HloParameterInstruction>(this)
3374       ->set_parameter_replicated_at_leaf_buffers(
3375           parameter_replicated_at_leaf_buffers);
3376 }
3377 
3378 const absl::optional<std::vector<bool>>&
parameter_replicated_at_leaf_buffers() const3379 HloInstruction::parameter_replicated_at_leaf_buffers() const {
3380   return Cast<HloParameterInstruction>(this)
3381       ->parameter_replicated_at_leaf_buffers();
3382 }
3383 
tuple_index() const3384 int64 HloInstruction::tuple_index() const {
3385   return Cast<HloGetTupleElementInstruction>(this)->tuple_index();
3386 }
3387 
exponent_bits() const3388 int32 HloInstruction::exponent_bits() const {
3389   return Cast<HloReducePrecisionInstruction>(this)->exponent_bits();
3390 }
3391 
mantissa_bits() const3392 int32 HloInstruction::mantissa_bits() const {
3393   return Cast<HloReducePrecisionInstruction>(this)->mantissa_bits();
3394 }
3395 
infeed_config() const3396 string HloInstruction::infeed_config() const {
3397   return Cast<HloInfeedInstruction>(this)->infeed_config();
3398 }
3399 
set_infeed_config(const string & config)3400 void HloInstruction::set_infeed_config(const string& config) {
3401   return Cast<HloInfeedInstruction>(this)->set_infeed_config(config);
3402 }
3403 
outfeed_shape() const3404 const Shape& HloInstruction::outfeed_shape() const {
3405   return Cast<HloOutfeedInstruction>(this)->outfeed_shape();
3406 }
3407 
outfeed_config() const3408 const string& HloInstruction::outfeed_config() const {
3409   return Cast<HloOutfeedInstruction>(this)->outfeed_config();
3410 }
3411 
replica_groups() const3412 const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
3413   return Cast<HloCollectiveInstruction>(this)->replica_groups();
3414 }
3415 
3416 const std::vector<std::pair<int64, int64>>&
source_target_pairs() const3417 HloInstruction::source_target_pairs() const {
3418   return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
3419 }
3420 
all_reduce_barrier() const3421 string HloInstruction::all_reduce_barrier() const {
3422   return Cast<HloAllReduceInstruction>(this)->all_reduce_barrier();
3423 }
3424 
set_all_reduce_barrier(const string & barrier)3425 void HloInstruction::set_all_reduce_barrier(const string& barrier) {
3426   return Cast<HloAllReduceInstruction>(this)->set_all_reduce_barrier(barrier);
3427 }
3428 
all_reduce_id() const3429 absl::optional<int64> HloInstruction::all_reduce_id() const {
3430   return Cast<HloAllReduceInstruction>(this)->all_reduce_id();
3431 }
3432 
set_all_reduce_id(const absl::optional<int64> & all_reduce_id)3433 void HloInstruction::set_all_reduce_id(
3434     const absl::optional<int64>& all_reduce_id) {
3435   return Cast<HloAllReduceInstruction>(this)->set_all_reduce_id(all_reduce_id);
3436 }
3437 
3438 const ConvolutionDimensionNumbers&
convolution_dimension_numbers() const3439 HloInstruction::convolution_dimension_numbers() const {
3440   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
3441     return convolution->convolution_dimension_numbers();
3442   }
3443   if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
3444     return custom_call->convolution_dimension_numbers();
3445   }
3446   LOG(FATAL) << "Unimplemented method.";
3447 }
3448 
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)3449 void HloInstruction::set_convolution_dimension_numbers(
3450     const ConvolutionDimensionNumbers& dnums) {
3451   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
3452     convolution->set_convolution_dimension_numbers(dnums);
3453   } else if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
3454     custom_call->set_convolution_dimension_numbers(dnums);
3455   } else {
3456     LOG(FATAL) << "Unimplemented method.";
3457   }
3458 }
3459 
feature_group_count() const3460 int64 HloInstruction::feature_group_count() const {
3461   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
3462     return convolution->feature_group_count();
3463   }
3464   return Cast<HloCustomCallInstruction>(this)->feature_group_count();
3465 }
3466 
set_feature_group_count(int64 feature_group_count)3467 void HloInstruction::set_feature_group_count(int64 feature_group_count) {
3468   Cast<HloCustomCallInstruction>(this)->set_feature_group_count(
3469       feature_group_count);
3470 }
3471 
batch_group_count() const3472 int64 HloInstruction::batch_group_count() const {
3473   if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
3474     return convolution->batch_group_count();
3475   }
3476   return Cast<HloCustomCallInstruction>(this)->batch_group_count();
3477 }
3478 
set_batch_group_count(int64 batch_group_count)3479 void HloInstruction::set_batch_group_count(int64 batch_group_count) {
3480   Cast<HloCustomCallInstruction>(this)->set_batch_group_count(
3481       batch_group_count);
3482 }
3483 
select() const3484 HloComputation* HloInstruction::select() const {
3485   return Cast<HloSelectAndScatterInstruction>(this)->select();
3486 }
3487 
scatter() const3488 HloComputation* HloInstruction::scatter() const {
3489   return Cast<HloSelectAndScatterInstruction>(this)->scatter();
3490 }
3491 
set_select(HloComputation * computation)3492 void HloInstruction::set_select(HloComputation* computation) {
3493   return Cast<HloSelectAndScatterInstruction>(this)->set_select(computation);
3494 }
3495 
set_scatter(HloComputation * computation)3496 void HloInstruction::set_scatter(HloComputation* computation) {
3497   return Cast<HloSelectAndScatterInstruction>(this)->set_scatter(computation);
3498 }
3499 
custom_call_target() const3500 const string& HloInstruction::custom_call_target() const {
3501   return Cast<HloCustomCallInstruction>(this)->custom_call_target();
3502 }
3503 
padding_config() const3504 const PaddingConfig& HloInstruction::padding_config() const {
3505   return Cast<HloPadInstruction>(this)->padding_config();
3506 }
3507 
slice_sizes(int64 dimension) const3508 int64 HloInstruction::slice_sizes(int64 dimension) const {
3509   return Cast<HloDynamicSliceInstruction>(this)->slice_sizes(dimension);
3510 }
3511 
dynamic_slice_sizes() const3512 const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const {
3513   return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes();
3514 }
3515 
gather_dimension_numbers() const3516 const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
3517   return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
3518 }
3519 
gather_slice_sizes() const3520 absl::Span<const int64> HloInstruction::gather_slice_sizes() const {
3521   return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
3522 }
3523 
scatter_dimension_numbers() const3524 const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
3525     const {
3526   return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
3527 }
3528 
dot_dimension_numbers() const3529 const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
3530   return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
3531 }
3532 
operand_side_metadata() const3533 const DomainMetadata& HloInstruction::operand_side_metadata() const {
3534   return Cast<HloDomainInstruction>(this)->operand_side_metadata();
3535 }
3536 
user_side_metadata() const3537 const DomainMetadata& HloInstruction::user_side_metadata() const {
3538   return Cast<HloDomainInstruction>(this)->user_side_metadata();
3539 }
3540 
comparison_direction() const3541 ComparisonDirection HloInstruction::comparison_direction() const {
3542   return Cast<HloCompareInstruction>(this)->direction();
3543 }
3544 
triangular_solve_options() const3545 const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
3546   return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
3547 }
3548 
cholesky_options() const3549 const CholeskyOptions& HloInstruction::cholesky_options() const {
3550   return Cast<HloCholeskyInstruction>(this)->cholesky_options();
3551 }
3552 
3553 }  // namespace xla
3554