1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
17 
#include <set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
35 
36 namespace xla {
37 
IsCallerInstruction(HloInstruction * hlo)38 bool IsCallerInstruction(HloInstruction* hlo) {
39   switch (hlo->opcode()) {
40     case HloOpcode::kCall:
41     case HloOpcode::kConditional:
42     case HloOpcode::kWhile:
43     case HloOpcode::kAllReduce:
44     case HloOpcode::kMap:
45     case HloOpcode::kReduce:
46     case HloOpcode::kReduceWindow:
47     case HloOpcode::kScatter:
48     case HloOpcode::kSelectAndScatter:
49     case HloOpcode::kSort:
50     case HloOpcode::kFusion:
51     case HloOpcode::kCustomCall:
52       return true;
53     default:
54       return false;
55   }
56 }
57 
58 namespace {
59 
CheckOperandCount(const HloInstruction * hlo,int expected)60 Status CheckOperandCount(const HloInstruction* hlo, int expected) {
61   if (hlo->operand_count() != expected) {
62     return InternalError("Expected %d operands for %s instruction: %s",
63                          expected, HloOpcodeString(hlo->opcode()),
64                          hlo->ToString());
65   }
66   return Status::OK();
67 }
68 
CheckParameterCount(const HloInstruction * calling_instruction,const HloComputation * computation,int expected)69 Status CheckParameterCount(const HloInstruction* calling_instruction,
70                            const HloComputation* computation, int expected) {
71   if (computation->num_parameters() != expected) {
72     return InternalError(
73         "Expected computation %s called from %s to have %d parameters, has %d",
74         computation->name(), calling_instruction->name(), expected,
75         computation->num_parameters());
76   }
77   return Status::OK();
78 }
79 }  // namespace
80 
Preprocess(HloInstruction * hlo)81 Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
82   if (!hlo->called_computations().empty() && !IsCallerInstruction(hlo)) {
83     return InternalError(
84         "Called computations specified for non-caller instruction  %s",
85         hlo->ToString());
86   }
87   absl::optional<int> arity = HloOpcodeArity(hlo->opcode());
88   if (arity) {
89     TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity));
90   }
91   return Status::OK();
92 }
93 
HandleElementwiseUnary(HloInstruction * hlo)94 Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
95   return CheckUnaryShape(hlo);
96 }
97 
HandleElementwiseBinary(HloInstruction * hlo)98 Status ShapeVerifier::HandleElementwiseBinary(HloInstruction* hlo) {
99   return CheckBinaryShape(hlo);
100 }
101 
HandleClamp(HloInstruction * clamp)102 Status ShapeVerifier::HandleClamp(HloInstruction* clamp) {
103   return CheckTernaryShape(clamp);
104 }
105 
HandleSelect(HloInstruction * select)106 Status ShapeVerifier::HandleSelect(HloInstruction* select) {
107   return CheckTernaryShape(select);
108 }
109 
HandleTupleSelect(HloInstruction * tuple_select)110 Status ShapeVerifier::HandleTupleSelect(HloInstruction* tuple_select) {
111   return CheckTernaryShape(tuple_select);
112 }
113 
HandleConcatenate(HloInstruction * concatenate)114 Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) {
115   std::vector<const Shape*> operand_shapes;
116   for (const HloInstruction* operand : concatenate->operands()) {
117     operand_shapes.push_back(&operand->shape());
118   }
119   return CheckShape(concatenate,
120                     ShapeInference::InferConcatOpShape(
121                         operand_shapes, concatenate->concatenate_dimension()));
122 }
123 
HandleConvert(HloInstruction * convert)124 Status ShapeVerifier::HandleConvert(HloInstruction* convert) {
125   return CheckShape(convert, ShapeInference::InferConvertShape(
126                                  convert->operand(0)->shape(),
127                                  convert->shape().element_type()));
128 }
129 
HandleBitcastConvert(HloInstruction * convert)130 Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) {
131   return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
132                                  convert->operand(0)->shape(),
133                                  convert->shape().element_type()));
134 }
135 
HandleCopy(HloInstruction * copy)136 Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
137   return CheckUnaryShape(copy);
138 }
139 
HandleDot(HloInstruction * dot)140 Status ShapeVerifier::HandleDot(HloInstruction* dot) {
141   TF_ASSIGN_OR_RETURN(
142       const Shape expected,
143       ShapeInference::InferDotOpShape(
144           dot->operand(0)->shape(), dot->operand(1)->shape(),
145           dot->dot_dimension_numbers(),
146           /*preferred_element_type=*/dot->shape().element_type()));
147   return CheckShape(dot, expected);
148 }
149 
HandleConvolution(HloInstruction * convolution)150 Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
151   TF_ASSIGN_OR_RETURN(
152       Shape expected,
153       ShapeInference::InferConvolveShape(
154           convolution->operand(0)->shape(), convolution->operand(1)->shape(),
155           convolution->feature_group_count(), convolution->batch_group_count(),
156           convolution->window(), convolution->convolution_dimension_numbers(),
157           /*preferred_element_type=*/convolution->shape().element_type()));
158   return CheckShape(convolution, expected);
159 }
160 
HandleFft(HloInstruction * fft)161 Status ShapeVerifier::HandleFft(HloInstruction* fft) {
162   TF_ASSIGN_OR_RETURN(
163       const Shape expected,
164       ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
165                                     fft->fft_length()));
166   return CheckShape(fft, expected);
167 }
168 
HandleTriangularSolve(HloInstruction * hlo)169 Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) {
170   TF_ASSIGN_OR_RETURN(const Shape expected,
171                       ShapeInference::InferTriangularSolveShape(
172                           hlo->operand(0)->shape(), hlo->operand(1)->shape(),
173                           hlo->triangular_solve_options()));
174   return CheckShape(hlo, expected);
175 }
176 
HandleCholesky(HloInstruction * hlo)177 Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
178   TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1));
179   TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferCholeskyShape(
180                                                 hlo->operand(0)->shape()));
181   return CheckShape(hlo, expected);
182 }
183 
184 // Checks that `hlo`'s set of ReplicaGroups:
185 //
186 //  - names each replica 0 through n-1 exactly once, and
187 //  - does not contain any empty ReplicaGroups.
188 //
189 // Note that although none of the groups may be empty, `hlo` is allowed to have
190 // 0 groups.  That just means it has one big group.
191 //
192 // This is just a minimal set of checks; some instructions may have additional
193 // requirements.  For example, all-to-all requires that all ReplicaGroups have
194 // the same number of replicas, but that isn't checked here.
CheckReplicaGroups(HloInstruction * hlo,bool use_global_device_ids)195 static Status CheckReplicaGroups(HloInstruction* hlo,
196                                  bool use_global_device_ids) {
197   std::set<int64> replicas_seen;
198   for (const ReplicaGroup& g : hlo->replica_groups()) {
199     if (g.replica_ids().empty()) {
200       return InternalError("Instruction cannot have an empty replica group: %s",
201                            hlo->ToString());
202     }
203     for (int64 i : g.replica_ids()) {
204       if (!replicas_seen.insert(i).second) {
205         return InternalError(
206             "Replica %d is repeated in instruction's replica-groups: %s", i,
207             hlo->ToString());
208       }
209     }
210   }
211   for (int64 i = 0; i < replicas_seen.size(); ++i) {
212     if (!replicas_seen.count(i)) {
213       return InternalError(
214           "Replica %d is not named in instruction's replica-groups: %s", i,
215           hlo->ToString());
216     }
217   }
218 
219   // If use_global_device_ids() is set, replica_groups cannot be empty.
220   // When the channel_id() or use_global_device_ids() is set, device ids in
221   // ReplicaGroup config no longer only mean replica ids. So we skip the check
222   // on the replica count.
223   if (use_global_device_ids) {
224     if (hlo->replica_groups().empty()) {
225       return InternalError(
226           "Replica group must be specified when use_global_device_ids is true");
227     }
228     // No need to check replica_count.
229     return Status::OK();
230   }
231 
232   if (auto channel_instr = DynCast<HloChannelInstruction>(hlo)) {
233     if (channel_instr->channel_id()) {
234       return Status::OK();
235     }
236   }
237 
238   int64 replica_count = hlo->GetModule()->config().replica_count();
239   if (replica_count != 1 && !replicas_seen.empty() &&
240       replicas_seen.size() != replica_count) {
241     return InternalError(
242         "Replica count in HloModuleConfig is %d, but ReplicaGroup config "
243         "contains %d replicas: %s",
244         replica_count, replicas_seen.size(), hlo->ToString());
245   }
246 
247   return Status::OK();
248 }
249 
HandleAllGather(HloInstruction * hlo)250 Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
251   auto ag = Cast<HloAllGatherInstruction>(hlo);
252   TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, ag->use_global_device_ids()));
253   TF_RET_CHECK(ag->all_gather_dimension() >= 0);
254   TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank());
255   TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank());
256 
257   int64 shard_count = CeilOfRatio(
258       ag->shape().dimensions(ag->all_gather_dimension()),
259       ag->operand(0)->shape().dimensions(ag->all_gather_dimension()));
260   if (ag->channel_id().has_value()) {
261     if (ag->use_global_device_ids()) {
262       TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size());
263     } else {
264       if (ag->replica_groups().empty() ||
265           ag->replica_groups()[0].replica_ids_size() != 1) {
266         return InternalError(
267             "Replica group size must be 1 when use_global_device_ids is "
268             "false if the all-gather is also cross-partition");
269       }
270     }
271   } else if (!ag->replica_groups().empty()) {
272     // Cross-replica all-gather: shard count is subgroup size.
273     TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size());
274   }
275   return CheckShape(ag, ShapeInference::InferAllGatherShape(
276                             ag->operand(0)->shape(), ag->all_gather_dimension(),
277                             shard_count));
278 }
279 
HandleAllReduce(HloInstruction * hlo)280 Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
281   auto ar = Cast<HloAllReduceInstruction>(hlo);
282   TF_RETURN_IF_ERROR(CheckReplicaGroups(ar, ar->use_global_device_ids()));
283 
284   std::vector<const Shape*> operand_shapes;
285   for (const HloInstruction* operand : hlo->operands()) {
286     operand_shapes.push_back(&operand->shape());
287   }
288   return CheckShape(hlo, ShapeInference::InferAllReduceShape(operand_shapes));
289 }
290 
HandleAllToAll(HloInstruction * hlo)291 Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
292   auto* all_to_all = Cast<HloAllToAllInstruction>(hlo);
293   TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, /*use_global_device_ids=*/false));
294 
295   TF_RET_CHECK(all_to_all != nullptr);
296   if (all_to_all->split_dimension()) {
297     if (hlo->replica_groups().empty()) {
298       return InternalError(
299           "An array all-to-all must have an explicit replica_groups config");
300     }
301   }
302 
303   // The size of each replica group must be the same (the split count of the
304   // operaion). In case the default replica group is used (empty replica group,
305   // must not be an array all-to-all, as checked above), infer from the number
306   // of operands.
307   const int64 split_count = hlo->replica_groups().empty()
308                                 ? hlo->operand_count()
309                                 : hlo->replica_groups()[0].replica_ids_size();
310   for (const ReplicaGroup& g : hlo->replica_groups()) {
311     if (g.replica_ids_size() != split_count) {
312       return InternalError(
313           "Replica group has size %d, but all replica groups in an all-to-all "
314           "must have size N: %s",
315           g.replica_ids_size(), hlo->ToString());
316     }
317   }
318 
319   if (all_to_all->split_dimension()) {
320     TF_RET_CHECK(hlo->operand_count() == 1);
321     return CheckShape(
322         hlo, ShapeInference::InferAllToAllShape(
323                  hlo->operand(0)->shape(), *all_to_all->split_dimension(),
324                  *all_to_all->split_dimension(), split_count));
325   } else {
326     std::vector<const Shape*> operand_shapes;
327     for (const HloInstruction* operand : hlo->operands()) {
328       operand_shapes.push_back(&operand->shape());
329     }
330     return CheckShape(hlo,
331                       ShapeInference::InferAllToAllTupleShape(operand_shapes));
332   }
333 }
334 
HandlePartitionId(HloInstruction * hlo)335 Status ShapeVerifier::HandlePartitionId(HloInstruction* hlo) {
336   return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
337 }
338 
HandleReplicaId(HloInstruction * hlo)339 Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) {
340   return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
341 }
342 
343 namespace {
344 
CheckDuplicatedSourceOrTarget(HloInstruction * hlo)345 Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo) {
346   // A source or target cannot appear twice in the collective-permute's
347   // source-target pairs.
348   absl::flat_hash_set<int64> seen_sources;
349   absl::flat_hash_set<int64> seen_targets;
350   for (const auto& p : hlo->source_target_pairs()) {
351     if (!seen_sources.insert(p.first).second) {
352       return InternalError(
353           "Source %d appears more than once in instruction's source-target "
354           "pairs: %s",
355           p.first, hlo->ToString());
356     }
357     if (!seen_targets.insert(p.second).second) {
358       return InternalError(
359           "Target %d appears more than once in instruction's source-target "
360           "pairs: %s",
361           p.second, hlo->ToString());
362     }
363   }
364   return Status::OK();
365 }
366 
367 }  // namespace
368 
HandleCollectivePermute(HloInstruction * hlo)369 Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
370   TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo));
371   return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape(
372                              hlo->operand(0)->shape()));
373 }
374 
HandleCollectivePermuteStart(HloInstruction * hlo)375 Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) {
376   TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo));
377   return CheckShape(
378       hlo, ShapeUtil::MakeTupleShape(
379                {hlo->operand(0)->shape(), hlo->operand(0)->shape(),
380                 ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}));
381 }
382 
HandleCollectivePermuteDone(HloInstruction * hlo)383 Status ShapeVerifier::HandleCollectivePermuteDone(HloInstruction* hlo) {
384   return CheckShape(
385       hlo, ShapeUtil::GetTupleElementShape(hlo->operand(0)->shape(), 0));
386 }
387 
HandleReducePrecision(HloInstruction * reduce_precision)388 Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
389   return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
390                                           reduce_precision->operand(0)->shape(),
391                                           reduce_precision->exponent_bits(),
392                                           reduce_precision->mantissa_bits()));
393 }
394 
CheckIsTokenOperand(const HloInstruction * instruction,int64 operand_no)395 Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction,
396                                           int64 operand_no) {
397   const HloInstruction* token = instruction->operand(operand_no);
398   if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
399     return InternalError(
400         "Expected operand %d to be token-shaped, actual shape is "
401         "%s:\n%s",
402         operand_no, StringifyShape(token->shape()), instruction->ToString());
403   }
404   return Status::OK();
405 }
406 
CheckOperandAndParameter(const HloInstruction * instruction,int64 operand_number,const HloComputation * computation,int64 parameter_number)407 Status ShapeVerifier::CheckOperandAndParameter(
408     const HloInstruction* instruction, int64 operand_number,
409     const HloComputation* computation, int64 parameter_number) {
410   const HloInstruction* operand = instruction->operand(operand_number);
411   const HloInstruction* parameter =
412       computation->parameter_instruction(parameter_number);
413   if (!ShapesSame(operand->shape(), parameter->shape())) {
414     return InternalError("Operand %s shape does not match parameter's %s in %s",
415                          operand->ToString(), parameter->ToString(),
416                          instruction->ToString());
417   }
418   return Status::OK();
419 }
420 
HandleInfeed(HloInstruction * instruction)421 Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
422   HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
423   TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
424 
425   // The output of infeed is a tuple containing the data value and a token.
426   return CheckShape(infeed,
427                     ShapeUtil::MakeTupleShape(
428                         {infeed->infeed_shape(), ShapeUtil::MakeTokenShape()}));
429 }
430 
HandleOutfeed(HloInstruction * instruction)431 Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
432   HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction);
433   TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1));
434 
435   // Outfeed has a separate shape field for the value which is outfed to the
436   // host. The shape of the instruction itself is always a token.
437   if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) {
438     return InternalError(
439         "Expected outfeed shape to be equal to operand's shape %s, "
440         "actual shape is %s:\n%s",
441         StringifyShape(outfeed->operand(0)->shape()),
442         StringifyShape(outfeed->outfeed_shape()), outfeed->ToString());
443   }
444   return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
445 }
446 
HasCompatibleElementTypes(const Shape & shape_0,const Shape & shape_1,const Shape & result_shape)447 bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
448                                               const Shape& shape_1,
449                                               const Shape& result_shape) {
450   return ShapeUtil::SameElementType(shape_0, shape_1) &&
451          (ShapeUtil::SameElementType(shape_0, result_shape) ||
452           (allow_mixed_precision_ &&
453            ShapeUtil::SameElementTypeIgnoringFpPrecision(shape_0,
454                                                          result_shape)));
455 }
456 
HandleRng(HloInstruction * instruction)457 Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
458   TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2));
459 
460   const Shape& shape_0 = instruction->operand(0)->shape();
461   const Shape& shape_1 = instruction->operand(1)->shape();
462   if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) {
463     return InternalError(
464         "Expected scalar types for the two operands of Rng instruction: %s",
465         instruction->ToString());
466   }
467 
468   if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) {
469     return InternalError(
470         "Expected compatible element types for the result and the two operands"
471         " of Rng instruction: %s",
472         instruction->ToString());
473   }
474 
475   PrimitiveType element_type = shape_0.element_type();
476   switch (instruction->random_distribution()) {
477     case RNG_UNIFORM:
478       if (!primitive_util::IsFloatingPointType(element_type) &&
479           !primitive_util::IsIntegralType(element_type) &&
480           element_type != PRED) {
481         return InternalError(
482             "Element type not supported."
483             " Expected element to be of floating point type, integral type or"
484             " predicate type for RngUniform: %s",
485             instruction->ToString());
486       }
487       break;
488 
489     case RNG_NORMAL:
490       if (!primitive_util::IsFloatingPointType(element_type)) {
491         return InternalError(
492             "Element type not supported."
493             " Expected element to be FloatingPointType for RngNormal: %s",
494             instruction->ToString());
495       }
496       break;
497     default:
498       return InternalError(
499           "Invalid Rng distribution %s",
500           RandomDistribution_Name(instruction->random_distribution()));
501   }
502 
503   return Status::OK();
504 }
505 
HandleRngBitGenerator(HloInstruction * hlo)506 Status ShapeVerifier::HandleRngBitGenerator(HloInstruction* hlo) {
507   if (!hlo->shape().IsTuple() || hlo->shape().tuple_shapes_size() != 2) {
508     return InternalError(
509         "Expected tuple shape with 2 elements for RngBitGenerator. Got: %s",
510         hlo->shape().ToString());
511   }
512   if (!ShapeUtil::Compatible(hlo->operand(0)->shape(),
513                              hlo->shape().tuple_shapes(0))) {
514     return InternalError(
515         "Expected state shape to match between input and output for "
516         "RngBitGenerator. Got %s vs. %s",
517         hlo->operand(0)->shape().ToString(),
518         hlo->shape().tuple_shapes(0).ToString());
519   }
520   return Status::OK();
521 }
522 
HandleRngGetAndUpdateState(HloInstruction * instruction)523 Status ShapeVerifier::HandleRngGetAndUpdateState(HloInstruction* instruction) {
524   TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0));
525   const Shape& result_shape = instruction->shape();
526   const Shape expected_shape = ShapeUtil::MakeShape(U64, {2});
527   if (!ShapeUtil::Compatible(result_shape, expected_shape)) {
528     return InternalError(
529         "Invalid RngGetAndUpdateState, expect result to have shape %s, got %s ",
530         StringifyShape(expected_shape), StringifyShape(result_shape));
531   }
532 
533   return Status::OK();
534 }
535 
HandleReverse(HloInstruction * reverse)536 Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
537   return CheckShape(
538       reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(),
539                                                  reverse->dimensions()));
540 }
541 
HandleSort(HloInstruction * sort)542 Status ShapeVerifier::HandleSort(HloInstruction* sort) {
543   if (sort->operand_count() < 1) {
544     return InternalError("Expected at least 1 operand for %s instruction: %s",
545                          HloOpcodeString(sort->opcode()), sort->ToString());
546   }
547   HloComputation* compare = sort->to_apply();
548 
549   // Check that the 'compare' computation returns a PRED.
550   Shape compare_shape = compare->root_instruction()->shape();
551   if (!ShapeUtil::Compatible(compare_shape, ShapeUtil::MakeShape(PRED, {}))) {
552     return InternalError(
553         "The Sort compare computation shape does not lead to a scalar "
554         "predicate shape: %s",
555         StringifyShape(compare_shape));
556   }
557 
558   // Check that the number of parameters of the 'compare' computation is
559   // correct.
560   TF_RETURN_IF_ERROR(
561       CheckParameterCount(sort, compare, sort->operand_count() * 2));
562 
563   // Verify that the operands of the compare computation have the correct scalar
564   // shapes.
565   for (int64 parameter_idx = 0; parameter_idx < compare->num_parameters();
566        ++parameter_idx) {
567     int64 operand_idx = parameter_idx / 2;
568     Shape expected_scalar_shape = ShapeUtil::MakeShape(
569         sort->operand(operand_idx)->shape().element_type(), {});
570     Shape actual_parameter_shape =
571         compare->parameter_instruction(parameter_idx)->shape();
572     if (!ShapeUtil::CompatibleIgnoringFpPrecision(expected_scalar_shape,
573                                                   actual_parameter_shape)) {
574       return InternalError(
575           "Expected the %lld-th parameter of the compare computation of sort "
576           "to have shape %s, but got %s",
577           parameter_idx, StringifyShape(expected_scalar_shape),
578           StringifyShape(actual_parameter_shape));
579     }
580   }
581 
582   // Verify that all operand shapes have the same dimensions.
583   for (int64 operand = 1; operand < sort->operand_count(); ++operand) {
584     if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(),
585                                    sort->operand(operand)->shape())) {
586       return InternalError(
587           "Expected sort to have to have the same dimensions for all operands. "
588           "First operand shape is: %s\n, shape (operand index %lld) is: %s",
589           StringifyShape(sort->operand(0)->shape()), operand,
590           StringifyShape(sort->operand(operand)->shape()));
591     }
592   }
593   return CheckVariadicShape(sort);
594 }
595 
HandleConstant(HloInstruction * constant)596 Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
597   if (!Cast<HloConstantInstruction>(constant)->HasLiteral()) {
598     return InternalError("Constant is required to have a valid literal: %s",
599                          constant->ToString());
600   }
601   return CheckShape(constant, constant->literal().shape(),
602                     /*only_compare_minor_to_major_in_layout=*/true);
603 }
604 
HandleIota(HloInstruction * hlo)605 Status ShapeVerifier::HandleIota(HloInstruction* hlo) {
606   auto* iota = Cast<HloIotaInstruction>(hlo);
607   if (!iota->shape().IsArray()) {
608     return InternalError("Iota does not support non-array result.");
609   }
610   const int64 rank = iota->shape().rank();
611   if (rank == 0) {
612     return InternalError("Iota does not support scalars.");
613   }
614   int64 iota_dimension = iota->iota_dimension();
615   if (iota_dimension >= rank || iota_dimension < 0) {
616     return InternalError(
617         "The iota dimension cannot go beyond the operation rank or be "
618         "negative.");
619   }
620 
621   PrimitiveType primitive_type = iota->shape().element_type();
622   if (!primitive_util::IsIntegralType(primitive_type) &&
623       !primitive_util::IsFloatingPointType(primitive_type) &&
624       !primitive_util::IsComplexType(primitive_type)) {
625     return InvalidArgument(
626         "Only support iota of integral, floating point or complex primitive "
627         "types, got %s",
628         PrimitiveType_Name(primitive_type));
629   }
630 
631   return Status::OK();
632 }
633 
HandleGetTupleElement(HloInstruction * get_tuple_element)634 Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
635   return CheckShape(get_tuple_element,
636                     ShapeInference::InferGetTupleElementShape(
637                         get_tuple_element->operand(0)->shape(),
638                         get_tuple_element->tuple_index()));
639 }
640 
641 namespace {
SameElementTypesForOperandsAndToApplyParameters(const HloInstruction & instruction,int64 num_operands_to_check)642 Status SameElementTypesForOperandsAndToApplyParameters(
643     const HloInstruction& instruction, int64 num_operands_to_check) {
644   const ProgramShape& to_apply = instruction.to_apply()->ComputeProgramShape();
645   for (int i = 0; i < num_operands_to_check; ++i) {
646     const Shape& parameter_shape = to_apply.parameters(i);
647     const Shape& operand_shape = instruction.operands()[i]->shape();
648     if (!ShapeUtil::SameElementType(parameter_shape, operand_shape)) {
649       return InvalidArgument(
650           "Shape mismatch between to_apply computation"
651           " parameter and operand %d in %s.",
652           i, instruction.ToString().c_str());
653     }
654   }
655   return Status::OK();
656 }
657 }  // namespace
658 
HandleReduce(HloInstruction * reduce)659 Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
660   if (reduce->operand_count() % 2 != 0) {
661     return InternalError(
662         "Expected an even number of operands for %s instruction: %s",
663         HloOpcodeString(reduce->opcode()), reduce->ToString());
664   }
665 
666   std::vector<const Shape*> operand_shapes;
667   for (const HloInstruction* operand : reduce->operands()) {
668     operand_shapes.push_back(&operand->shape());
669   }
670   TF_RETURN_IF_ERROR(
671       CheckShape(reduce, ShapeInference::InferReduceShape(
672                              operand_shapes, reduce->dimensions(),
673                              reduce->to_apply()->ComputeProgramShape())));
674 
675   return allow_mixed_precision_
676              ? Status::OK()
677              : SameElementTypesForOperandsAndToApplyParameters(
678                    *reduce, reduce->operands().size() - 1);
679 }
680 
HandleBitcast(HloInstruction * bitcast)681 Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
682   if (layout_sensitive_ &&
683       shape_size_function_(bitcast->shape()) !=
684           shape_size_function_(bitcast->operand(0)->shape())) {
685     return InternalError(
686         "Bitcast cannot have different shape sizes of output (%d) and operand "
687         "(%d) (%s) (%s)",
688         shape_size_function_(bitcast->shape()),
689         shape_size_function_(bitcast->operand(0)->shape()),
690         bitcast->shape().ToString(true),
691         bitcast->operand(0)->shape().ToString(true));
692   }
693   return Status::OK();
694 }
695 
HandleBroadcast(HloInstruction * broadcast)696 Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
697   // HLO broadcast has no exact analog at the proto level so there is no
698   // ShapeInference method. Check the output shape explicitly.
699   const Shape& operand_shape = broadcast->operand(0)->shape();
700   // Check for mixed precision.
701   TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape));
702   TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size());
703   for (int64 operand_dimension = 0; operand_dimension < operand_shape.rank();
704        ++operand_dimension) {
705     int64 output_dimension = broadcast->dimensions()[operand_dimension];
706     TF_RET_CHECK((output_dimension < broadcast->shape().rank()) &&
707                  output_dimension >= 0 &&
708                  (broadcast->shape().dimensions(output_dimension) ==
709                   operand_shape.dimensions(operand_dimension)))
710         << broadcast->ToString() << " operand shape " << operand_shape;
711   }
712   return Status::OK();
713 }
714 
HandleDynamicReshape(HloInstruction * dynamic_reshape)715 Status ShapeVerifier::HandleDynamicReshape(HloInstruction* dynamic_reshape) {
716   // Check for mixed precision.
717   const Shape& operand_shape = dynamic_reshape->operand(0)->shape();
718   TF_RET_CHECK(SameElementType(dynamic_reshape->shape(), operand_shape));
719   TF_RET_CHECK(ShapeUtil::ElementsIn(dynamic_reshape->shape()) ==
720                ShapeUtil::ElementsIn(operand_shape));
721   TF_RET_CHECK(dynamic_reshape->shape().rank() + 1 ==
722                dynamic_reshape->operand_count());
723   for (int64 i = 1; i < dynamic_reshape->operand_count(); ++i) {
724     TF_RET_CHECK(dynamic_reshape->operand(i)->shape().element_type() == S32);
725   }
726   return Status::OK();
727 }
728 
HandleReshape(HloInstruction * reshape)729 Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
730   // Check for mixed precision.
731   const Shape& operand_shape = reshape->operand(0)->shape();
732   TF_RET_CHECK(SameElementType(reshape->shape(), operand_shape));
733   TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
734                ShapeUtil::ElementsIn(operand_shape));
735   return Status::OK();
736 }
737 
HandleTranspose(HloInstruction * transpose)738 Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
739   return CheckShape(
740       transpose, ShapeInference::InferTransposeShape(
741                      transpose->operand(0)->shape(), transpose->dimensions()));
742 }
743 
HandleParameter(HloInstruction * hlo)744 Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
745   return Status::OK();
746 }
747 
HandleFusion(HloInstruction * fusion)748 Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
749   if (fusion->called_computations().size() != 1) {
750     return InternalError(
751         "Fusion has a non-unary number of called computations (%s)",
752         fusion->ToString().c_str());
753   }
754   const Shape& root_computation_shape =
755       fusion->called_computations()[0]->root_instruction()->shape();
756   if (!ShapesSame(fusion->shape(), root_computation_shape)) {
757     return InternalError(
758         "Fused computation shape (%s) is not equal to the fusion shape (%s)",
759         root_computation_shape.ToString(true), fusion->shape().ToString(true));
760   }
761 
762   auto& fused_parameters = fusion->fused_parameters();
763   if (fused_parameters.size() != fusion->operand_count()) {
764     return InternalError(
765         "Fused parameter count (%d) does not match the number of operands (%d)"
766         " passed to the fusion instruction in: %s.",
767         fused_parameters.size(), fusion->operand_count(),
768         fusion->ToString().c_str());
769   }
770   for (HloInstruction* fused_param : fused_parameters) {
771     int64 param_no = fused_param->parameter_number();
772     if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) {
773       return InternalError(
774           "Shape mismatch between parameter number %d and its operand in "
775           "%s.",
776           param_no, fusion->ToString().c_str());
777     }
778   }
779   return Status::OK();
780 }
781 
HandleCall(HloInstruction * call)782 Status ShapeVerifier::HandleCall(HloInstruction* call) {
783   TF_RETURN_IF_ERROR(
784       CheckParameterCount(call, call->to_apply(), call->operand_count()));
785   for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
786     TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
787   }
788   // The shape of kCall should match the shape of the computation it calls.
789   return CheckShape(call, call->to_apply()->root_instruction()->shape());
790 }
791 
HandleCustomCall(HloInstruction * instruction)792 Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
793   const HloCustomCallInstruction* custom_call =
794       DynCast<const HloCustomCallInstruction>(instruction);
795   TF_RET_CHECK(custom_call != nullptr);
796   if (custom_call->layout_constrained()) {
797     // If the layout is constrained, verify all the respective shapes have
798     // layouts and that the constrained operand shapes match the shapes of the
799     // operands.
800     TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape()));
801     TF_RET_CHECK(custom_call->operand_count() ==
802                  custom_call->operand_shapes_with_layout().size());
803     for (int64 i = 0; i < custom_call->operand_count(); ++i) {
804       const Shape& operand_shape_with_layout =
805           custom_call->operand_shapes_with_layout()[i];
806       TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(),
807                                          operand_shape_with_layout))
808           << custom_call->operand(i)->shape().ToString() << " operand "
809           << operand_shape_with_layout.ToString();
810       TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
811     }
812   }
813   for (const auto& pair : custom_call->output_to_operand_aliasing()) {
814     TF_RET_CHECK(pair.second.first < custom_call->operand_count())
815         << "Invalid aliasing operand index.";
816     TF_RET_CHECK(ShapeUtil::IndexIsValid(
817         custom_call->operand(pair.second.first)->shape(), pair.second.second))
818         << "Invalid aliasing operand shape index.";
819     TF_RET_CHECK(ShapeUtil::IndexIsValid(custom_call->shape(), pair.first))
820         << "Invalid aliasing output shape index.";
821     const Shape& output_subshape =
822         ShapeUtil::GetSubshape(custom_call->shape(), pair.first);
823     const Shape& operand_subshape = ShapeUtil::GetSubshape(
824         custom_call->operand(pair.second.first)->shape(), pair.second.second);
825     if (layout_sensitive_) {
826       TF_RET_CHECK(operand_subshape == output_subshape)
827           << "Different aliasing shapes: " << operand_subshape.ToString()
828           << " vs " << output_subshape.ToString();
829     } else {
830       TF_RET_CHECK(ShapeUtil::Compatible(output_subshape, operand_subshape))
831           << "Different aliasing shapes: " << operand_subshape.ToString()
832           << " vs " << output_subshape.ToString();
833     }
834   }
835   return Status::OK();
836 }
837 
HandleSlice(HloInstruction * slice)838 Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
839   return CheckShape(slice,
840                     ShapeInference::InferSliceShape(
841                         slice->operand(0)->shape(), slice->slice_starts(),
842                         slice->slice_limits(), slice->slice_strides()));
843 }
844 
HandleDynamicSlice(HloInstruction * dynamic_slice)845 Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
846   return CheckShape(
847       dynamic_slice,
848       ShapeInference::InferDynamicSliceShape(
849           dynamic_slice->operand(0)->shape(),
850           Cast<HloDynamicSliceInstruction>(dynamic_slice)->index_shapes(),
851           dynamic_slice->dynamic_slice_sizes()));
852 }
853 
HandleDynamicUpdateSlice(HloInstruction * dynamic_update_slice)854 Status ShapeVerifier::HandleDynamicUpdateSlice(
855     HloInstruction* dynamic_update_slice) {
856   return CheckShape(
857       dynamic_update_slice,
858       ShapeInference::InferDynamicUpdateSliceShape(
859           dynamic_update_slice->operand(0)->shape(),
860           dynamic_update_slice->operand(1)->shape(),
861           Cast<HloDynamicUpdateSliceInstruction>(dynamic_update_slice)
862               ->index_shapes()));
863 }
864 
HandleTuple(HloInstruction * tuple)865 Status ShapeVerifier::HandleTuple(HloInstruction* tuple) {
866   return CheckVariadicShape(tuple);
867 }
868 
HandleMap(HloInstruction * map)869 Status ShapeVerifier::HandleMap(HloInstruction* map) {
870   std::vector<const Shape*> operand_shapes;
871   int64 max_operand_rank = 0;
872   for (const HloInstruction* operand : map->operands()) {
873     operand_shapes.push_back(&operand->shape());
874     max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
875   }
876   // TODO(b/65689298) Remove code below once Map is generalized to accept
877   // arbitrary map dimensions.
878   std::vector<int64> map_dims(max_operand_rank);
879   std::iota(map_dims.begin(), map_dims.end(), 0);
880 
881   TF_RETURN_IF_ERROR(CheckShape(
882       map,
883       ShapeInference::InferMapShape(
884           operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims)));
885 
886   return allow_mixed_precision_
887              ? Status::OK()
888              : SameElementTypesForOperandsAndToApplyParameters(
889                    *map, map->operands().size());
890 }
891 
HandleReduceWindow(HloInstruction * reduce_window)892 Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
893   VLOG(2) << "Verify reduce window:" << reduce_window->ToString() << "\n";
894   auto reduce_window_instr = Cast<HloReduceWindowInstruction>(reduce_window);
895   auto input_shapes = reduce_window_instr->input_array_shapes();
896   VLOG(2) << "reduce window input shape count: " << input_shapes.size() << "\n";
897   auto init_shapes = reduce_window_instr->init_value_shapes();
898   VLOG(2) << "reduce instruction is :" << reduce_window->ToString() << "\n";
899   TF_RETURN_IF_ERROR(CheckShape(
900       reduce_window, ShapeInference::InferReduceWindowShape(
901                          input_shapes, init_shapes, reduce_window->window(),
902                          reduce_window->to_apply()->ComputeProgramShape())));
903 
904   return allow_mixed_precision_
905              ? Status::OK()
906              : SameElementTypesForOperandsAndToApplyParameters(*reduce_window,
907                                                                1);
908 }
909 
HandleSelectAndScatter(HloInstruction * instruction)910 Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
911   return CheckShape(
912       instruction,
913       ShapeInference::InferSelectAndScatterShape(
914           instruction->operand(0)->shape(),
915           instruction->select()->ComputeProgramShape(), instruction->window(),
916           instruction->operand(1)->shape(), instruction->operand(2)->shape(),
917           instruction->scatter()->ComputeProgramShape()));
918 }
919 
HandleWhile(HloInstruction * xla_while)920 Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
921   TF_RETURN_IF_ERROR(
922       CheckParameterCount(xla_while, xla_while->while_body(), 1));
923   TF_RETURN_IF_ERROR(
924       CheckParameterCount(xla_while, xla_while->while_condition(), 1));
925   TF_RETURN_IF_ERROR(
926       CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
927   TF_RETURN_IF_ERROR(
928       CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
929   const Shape& conditional_shape =
930       xla_while->while_condition()->root_instruction()->shape();
931   if (!ShapeUtil::Compatible(conditional_shape,
932                              ShapeUtil::MakeShape(PRED, {}))) {
933     return InternalError(
934         "Conditional computation shape does not lead to a scalar predicate "
935         "shape: %s",
936         StringifyShape(conditional_shape));
937   }
938   // The shape of kWhile should match the shape of the body computation it
939   // calls.
940   return CheckShape(xla_while,
941                     xla_while->while_body()->root_instruction()->shape());
942 }
943 
HandleConditional(HloInstruction * conditional)944 Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
945   if (!ShapeUtil::IsScalar(conditional->operand(0)->shape())) {
946     return InvalidArgument(
947         "The first operand of conditional must be a scalar. Got %s",
948         conditional->operand(0)->shape().DebugString());
949   }
950   const int num_branches = conditional->branch_count();
951   PrimitiveType operand0_type = conditional->operand(0)->shape().element_type();
952   if (operand0_type == PRED) {
953     TF_RET_CHECK(num_branches == 2);
954   } else {
955     if (operand0_type != S32) {
956       return InvalidArgument(
957           "The first operand of indexed conditional must be a scalar of S32. "
958           "Got"
959           " type %s.",
960           PrimitiveType_Name(operand0_type));
961     }
962     TF_RET_CHECK(num_branches >= 1);
963   }
964   TF_RETURN_IF_ERROR(CheckOperandCount(conditional, num_branches + 1));
965   for (int j = 0; j < num_branches; ++j) {
966     TF_RETURN_IF_ERROR(CheckParameterCount(
967         conditional, conditional->branch_computation(j), 1));
968     TF_RETURN_IF_ERROR(CheckOperandAndParameter(
969         conditional, j + 1, conditional->branch_computation(j), 0));
970     TF_RETURN_IF_ERROR(CheckShape(
971         conditional,
972         conditional->branch_computation(j)->root_instruction()->shape()));
973   }
974   return Status::OK();
975 }
976 
HandlePad(HloInstruction * pad)977 Status ShapeVerifier::HandlePad(HloInstruction* pad) {
978   return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(),
979                                                        pad->operand(1)->shape(),
980                                                        pad->padding_config()));
981 }
982 
HandleCopyStart(HloInstruction * copy_start)983 Status ShapeVerifier::HandleCopyStart(HloInstruction* copy_start) {
984   return CheckShape(copy_start,
985                     ShapeUtil::MakeTupleShape({copy_start->operand(0)->shape(),
986                                                copy_start->operand(0)->shape(),
987                                                ShapeUtil::MakeShape(U32, {})}),
988                     /*only_compare_minor_to_major_in_layout=*/true);
989 }
990 
HandleCopyDone(HloInstruction * copy_done)991 Status ShapeVerifier::HandleCopyDone(HloInstruction* copy_done) {
992   const Shape& operand_shape = copy_done->operand(0)->shape();
993   const Shape& dest_shape = ShapeUtil::GetTupleElementShape(operand_shape, 0);
994   const Shape& src_shape = ShapeUtil::GetTupleElementShape(operand_shape, 1);
995   if (!ShapesSame(dest_shape, src_shape,
996                   /*minor_to_major_only=*/false,
997                   /*ignore_memory_space=*/true)) {
998     return InternalError(
999         "Source and destination buffers in CopyDone arguments need to be the "
1000         "same shape found %s and %s\n%s",
1001         StringifyShape(dest_shape), StringifyShape(src_shape),
1002         copy_done->ToString());
1003   }
1004   return CheckShape(copy_done, ShapeUtil::GetTupleElementShape(
1005                                    copy_done->operand(0)->shape(), 0));
1006 }
1007 
HandleSend(HloInstruction * send)1008 Status ShapeVerifier::HandleSend(HloInstruction* send) {
1009   return CheckShape(send,
1010                     ShapeUtil::MakeTupleShape({send->operand(0)->shape(),
1011                                                ShapeUtil::MakeShape(U32, {}),
1012                                                ShapeUtil::MakeTokenShape()}),
1013                     /*only_compare_minor_to_major_in_layout=*/true);
1014 }
1015 
HandleSendDone(HloInstruction * send_done)1016 Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
1017   return CheckShape(send_done, ShapeUtil::MakeTokenShape());
1018 }
1019 
HandleRecv(HloInstruction * recv)1020 Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
1021   return CheckShape(
1022       recv,
1023       ShapeUtil::MakeTupleShape(
1024           {ShapeUtil::GetTupleElementShape(recv->shape(), 0),
1025            ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}),
1026       /*only_compare_minor_to_major_in_layout=*/true);
1027 }
1028 
HandleRecvDone(HloInstruction * recv_done)1029 Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
1030   return CheckShape(
1031       recv_done,
1032       ShapeUtil::MakeTupleShape(
1033           {ShapeUtil::GetTupleElementShape(recv_done->operand(0)->shape(), 0),
1034            ShapeUtil::MakeTokenShape()}));
1035 }
1036 
HandleBatchNormTraining(HloInstruction * batch_norm_training)1037 Status ShapeVerifier::HandleBatchNormTraining(
1038     HloInstruction* batch_norm_training) {
1039   return CheckShape(batch_norm_training,
1040                     ShapeInference::InferBatchNormTrainingShape(
1041                         batch_norm_training->operand(0)->shape(),
1042                         batch_norm_training->operand(1)->shape(),
1043                         batch_norm_training->operand(2)->shape(),
1044                         batch_norm_training->feature_index()));
1045 }
1046 
HandleBatchNormInference(HloInstruction * batch_norm_inference)1047 Status ShapeVerifier::HandleBatchNormInference(
1048     HloInstruction* batch_norm_inference) {
1049   return CheckShape(batch_norm_inference,
1050                     ShapeInference::InferBatchNormInferenceShape(
1051                         batch_norm_inference->operand(0)->shape(),
1052                         batch_norm_inference->operand(1)->shape(),
1053                         batch_norm_inference->operand(2)->shape(),
1054                         batch_norm_inference->operand(3)->shape(),
1055                         batch_norm_inference->operand(4)->shape(),
1056                         batch_norm_inference->feature_index()));
1057 }
1058 
HandleBatchNormGrad(HloInstruction * batch_norm_grad)1059 Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
1060   return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape(
1061                                          batch_norm_grad->operand(0)->shape(),
1062                                          batch_norm_grad->operand(1)->shape(),
1063                                          batch_norm_grad->operand(2)->shape(),
1064                                          batch_norm_grad->operand(3)->shape(),
1065                                          batch_norm_grad->operand(4)->shape(),
1066                                          batch_norm_grad->feature_index()));
1067 }
1068 
1069 namespace {
1070 
1071 // Checks that the instruction does not have mixed precision floating point
1072 // inputs.
CheckMixedPrecisionOperands(const HloInstruction * instruction)1073 Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
1074   switch (instruction->opcode()) {
1075     // Allow-list the following opcodes for mixed-precision check, because
1076     // they involve data pass through or grouping via tuples, where the
1077     // precisions of buffers can be different.
1078     case HloOpcode::kCall:
1079     case HloOpcode::kConditional:
1080     case HloOpcode::kConstant:
1081     case HloOpcode::kConvolution:
1082     case HloOpcode::kDot:
1083     case HloOpcode::kAllReduce:
1084     case HloOpcode::kCopyDone:
1085     case HloOpcode::kCopyStart:
1086     case HloOpcode::kCustomCall:
1087     case HloOpcode::kDomain:
1088     case HloOpcode::kFusion:
1089     case HloOpcode::kGetTupleElement:
1090     case HloOpcode::kInfeed:
1091     case HloOpcode::kOutfeed:
1092     case HloOpcode::kParameter:
1093     case HloOpcode::kRecv:
1094     case HloOpcode::kRecvDone:
1095     case HloOpcode::kReducePrecision:
1096     case HloOpcode::kReduceWindow:
1097     case HloOpcode::kTupleSelect:
1098     case HloOpcode::kSend:
1099     case HloOpcode::kSendDone:
1100     case HloOpcode::kSort:
1101     case HloOpcode::kTuple:
1102     case HloOpcode::kWhile:
1103       break;
1104     default: {
1105       PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID;
1106       for (auto operand : instruction->operands()) {
1107         TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
1108             operand->shape(),
1109             [&](const Shape& subshape, const ShapeIndex& index) {
1110               if (!ShapeUtil::ElementIsFloating(subshape)) {
1111                 return Status::OK();
1112               }
1113               if (fp_type == PRIMITIVE_TYPE_INVALID) {
1114                 fp_type = subshape.element_type();
1115               } else if (fp_type != subshape.element_type()) {
1116                 return InternalError(
1117                     "Seen floating point types of different precisions in "
1118                     "%s, but mixed precision is disallowed.",
1119                     instruction->ToString());
1120               }
1121               return Status::OK();
1122             }));
1123       }
1124     }
1125   }
1126   return Status::OK();
1127 }
1128 
1129 }  // namespace
1130 
HandleGather(HloInstruction * gather)1131 Status ShapeVerifier::HandleGather(HloInstruction* gather) {
1132   return CheckShape(
1133       gather,
1134       ShapeInference::InferGatherShape(
1135           gather->operand(0)->shape(), gather->operand(1)->shape(),
1136           gather->gather_dimension_numbers(), gather->gather_slice_sizes()));
1137 }
1138 
HandleScatter(HloInstruction * scatter)1139 Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
1140   return CheckShape(
1141       scatter, ShapeInference::InferScatterShape(
1142                    scatter->operand(0)->shape(), scatter->operand(1)->shape(),
1143                    scatter->operand(2)->shape(),
1144                    scatter->to_apply()->ComputeProgramShape(),
1145                    scatter->scatter_dimension_numbers()));
1146 }
1147 
HandleAfterAll(HloInstruction * token)1148 Status ShapeVerifier::HandleAfterAll(HloInstruction* token) {
1149   std::vector<const Shape*> operand_shapes;
1150   for (const HloInstruction* operand : token->operands()) {
1151     operand_shapes.push_back(&operand->shape());
1152   }
1153   return CheckShape(token, ShapeUtil::MakeTokenShape());
1154 }
1155 
HandleAddDependency(HloInstruction * add_dependency)1156 Status ShapeVerifier::HandleAddDependency(HloInstruction* add_dependency) {
1157   TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1));
1158   return CheckShape(add_dependency, add_dependency->operand(0)->shape());
1159 }
1160 
HandleGetDimensionSize(HloInstruction * get_size)1161 Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
1162   return CheckShape(get_size,
1163                     ShapeInference::InferGetDimensionSizeShape(
1164                         get_size->operand(0)->shape(), get_size->dimension()));
1165 }
1166 
HandleSetDimensionSize(HloInstruction * set_size)1167 Status ShapeVerifier::HandleSetDimensionSize(HloInstruction* set_size) {
1168   return CheckShape(set_size,
1169                     ShapeInference::InferSetDimensionSizeShape(
1170                         set_size->operand(0)->shape(),
1171                         set_size->operand(1)->shape(), set_size->dimension()));
1172 }
1173 
CheckShape(const HloInstruction * instruction,const Shape & inferred_shape,bool only_compare_minor_to_major_in_layout)1174 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
1175                                  const Shape& inferred_shape,
1176                                  bool only_compare_minor_to_major_in_layout) {
1177   // If allow_mixed_precision_ is false, check if there are operands with
1178   // different precisions. We need this check because ShapeInference allows
1179   // mixed precision inputs.
1180   if (!allow_mixed_precision_) {
1181     TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction));
1182   }
1183 
1184   // Check if the output shape matches the expected shape.
1185   //
1186   // We treat BF16 and F32 as compatible types if mixed precision is allowed,
1187   // but only when the instruction defines the BF16/F32 buffer.
1188   bool equal = [&] {
1189     switch (instruction->opcode()) {
1190       // The opcodes below can't have implicit layout conversions, nor can they
1191       // implicitly transform f32 -> bf16.  Fundamentally these are either
1192       // reinterpreting existing data (e.g. kBitcast) or shuffling data around
1193       // without modifying it (e.g. kGetTupleElement, kTupleSelect).
1194       case HloOpcode::kBitcast:
1195       case HloOpcode::kCall:
1196       case HloOpcode::kConditional:
1197       case HloOpcode::kConstant:
1198       case HloOpcode::kCopyDone:
1199       case HloOpcode::kCopyStart:
1200       case HloOpcode::kCustomCall:
1201       case HloOpcode::kDynamicUpdateSlice:
1202       case HloOpcode::kGetTupleElement:
1203       case HloOpcode::kInfeed:
1204       case HloOpcode::kOutfeed:
1205       case HloOpcode::kParameter:
1206       case HloOpcode::kRecv:
1207       case HloOpcode::kRecvDone:
1208       case HloOpcode::kSend:
1209       case HloOpcode::kSendDone:
1210       case HloOpcode::kTuple:
1211       case HloOpcode::kTupleSelect:
1212       case HloOpcode::kWhile:
1213         return ShapesSame(instruction->shape(), inferred_shape,
1214                           only_compare_minor_to_major_in_layout);
1215 
1216       // We allow arbitrary layout and f32->bf16 transformations on all other
1217       // instructions, although this may be made more strict pending discussion
1218       // in b/112709536.
1219       default:
1220         if (allow_mixed_precision_) {
1221           return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(),
1222                                                           inferred_shape);
1223         } else {
1224           return ShapeUtil::Compatible(instruction->shape(), inferred_shape);
1225         }
1226     }
1227   }();
1228   if (!equal) {
1229     return InternalError(
1230         "Expected instruction to have shape equal to %s, actual "
1231         "shape is %s:\n%s",
1232         StringifyShape(inferred_shape), StringifyShape(instruction->shape()),
1233         instruction->ToString());
1234   }
1235   return Status::OK();
1236 }
1237 
CheckShape(const HloInstruction * instruction,const StatusOr<Shape> & inferred_shape_status)1238 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
1239                                  const StatusOr<Shape>& inferred_shape_status) {
1240   if (!inferred_shape_status.ok()) {
1241     Status s = inferred_shape_status.status();
1242     tensorflow::errors::AppendToMessage(&s, ", for instruction ",
1243                                         instruction->ToString());
1244     return s;
1245   }
1246   return CheckShape(instruction, inferred_shape_status.ValueOrDie());
1247 }
1248 
CheckUnaryShape(const HloInstruction * instruction)1249 Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
1250   return CheckShape(instruction,
1251                     ShapeInference::InferUnaryOpShape(instruction->opcode(),
1252                                                       instruction->operand(0)));
1253 }
1254 
CheckBinaryShape(const HloInstruction * instruction)1255 Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) {
1256   return CheckShape(
1257       instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(),
1258                                                       instruction->operand(0),
1259                                                       instruction->operand(1)));
1260 }
1261 
CheckTernaryShape(const HloInstruction * instruction)1262 Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) {
1263   return CheckShape(instruction,
1264                     ShapeInference::InferTernaryOpShape(
1265                         instruction->opcode(), instruction->operand(0),
1266                         instruction->operand(1), instruction->operand(2)));
1267 }
1268 
CheckVariadicShape(const HloInstruction * instruction)1269 Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
1270   return CheckShape(instruction,
1271                     ShapeInference::InferVariadicOpShape(
1272                         instruction->opcode(), instruction->operands()));
1273 }
1274 
VerifyEntryComputationLayout(const HloModule & module)1275 Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) {
1276   const HloComputation* computation = module.entry_computation();
1277   const auto& layout = module.entry_computation_layout();
1278   const ShapeLayout& result_layout = layout.result_layout();
1279 
1280   TF_RETURN_IF_ERROR(
1281       ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape()));
1282 
1283   if (!ShapeUtil::Compatible(computation->root_instruction()->shape(),
1284                              result_layout.shape())) {
1285     return InternalError(
1286         "Shape of the root instruction of entry computation (%s) should be "
1287         "compatible to one specified in module's entry computation layout (%s)",
1288         ShapeUtil::HumanString(computation->root_instruction()->shape()),
1289         ShapeUtil::HumanString(result_layout.shape()));
1290   }
1291 
1292   if (computation->num_parameters() != layout.parameter_count()) {
1293     return InternalError(
1294         "Number of parameters in entry computation layout (%d) must be same "
1295         "as number of parameters of entry computation (%d)",
1296         layout.parameter_count(), computation->num_parameters());
1297   }
1298 
1299   for (int i = 0; i < computation->num_parameters(); ++i) {
1300     const HloInstruction* parameter = computation->parameter_instruction(i);
1301     TF_RETURN_IF_ERROR(
1302         ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i)));
1303     if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) {
1304       return InternalError(
1305           "Shape of the entry computation parameter %d is %s should be "
1306           "compatible to the one specified in module's entry computation "
1307           "layout %s",
1308           i, ShapeUtil::HumanString(parameter->shape()),
1309           ShapeUtil::HumanString(layout.parameter_shape(i)));
1310     }
1311   }
1312 
1313   return Status::OK();
1314 }
1315 
ComputationsToString(absl::Span<HloComputation * const> computations)1316 string ComputationsToString(absl::Span<HloComputation* const> computations) {
1317   return absl::StrJoin(computations, ",",
1318                        [](string* s, const HloComputation* computation) {
1319                          s->append(computation->name());
1320                        });
1321 }
1322 
1323 // Verifies various invariants about the structure of the HLO:
1324 //
1325 // (1) each instruction has a non-null parent() set to the HloComputation
1326 // which
1327 //     contains it.
1328 //
1329 // (2) each computation has a non-null parent() set to the HloModule which
1330 //     contains it.
1331 //
1332 // (3) the operands of each instruction are in the same computation as the
1333 //     instruction.
VerifyHloStructure(HloModule * module)1334 Status VerifyHloStructure(HloModule* module) {
1335   for (const HloComputation* computation : module->computations()) {
1336     if (computation->parent() == nullptr) {
1337       return InternalError("Computation %s has a null parent pointer",
1338                            computation->name());
1339     }
1340     if (computation->parent() != module) {
1341       return InternalError(
1342           "Computation %s parent() does not point to parent module",
1343           computation->name());
1344     }
1345 
1346     for (const HloInstruction* instruction : computation->instructions()) {
1347       if (instruction->parent() == nullptr) {
1348         return InternalError("Instruction %s has a null parent pointer",
1349                              instruction->name());
1350       }
1351       if (instruction->parent() != computation) {
1352         return InternalError(
1353             "Instruction %s parent() does not point to parent computation",
1354             instruction->name());
1355       }
1356     }
1357   }
1358 
1359   // Check that operands are in the same computation separately from verifying
1360   // parent() correctness so conditions like a null HloInstruction::parent()
1361   // are identified and reported explicitly above rather than reporting a
1362   // mismatched operand.
1363   for (const HloComputation* computation : module->computations()) {
1364     for (const HloInstruction* instruction : computation->instructions()) {
1365       for (int i = 0; i < instruction->operand_count(); ++i) {
1366         const HloInstruction* operand = instruction->operand(i);
1367         if (operand->parent() != instruction->parent()) {
1368           return InternalError(
1369               "Operand %d (%s) of instruction %s is in a different "
1370               "computation: %s vs %s",
1371               i, operand->name(), instruction->name(),
1372               operand->parent() ? operand->parent()->name() : "(null)",
1373               instruction->parent()->name());
1374         }
1375       }
1376     }
1377   }
1378   return Status::OK();
1379 }
1380 
1381 namespace {
1382 
1383 // Returns true if the given Shape has a TOKEN shape as any subshape.
ShapeContainsToken(const Shape & shape)1384 bool ShapeContainsToken(const Shape& shape) {
1385   bool contains_token = false;
1386   ShapeUtil::ForEachSubshape(
1387       shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
1388         if (subshape.IsToken()) {
1389           contains_token = true;
1390         }
1391       });
1392   return contains_token;
1393 }
1394 
1395 // Verifies that all types entering and exiting the entry computation are
1396 // legal.
VerifyEntryAndExitShapes(const HloModule & module)1397 Status VerifyEntryAndExitShapes(const HloModule& module) {
1398   // Tokens cannot be passed as entry parameters.
1399   // TODO(b/80000000): Remove this constraint.
1400   for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
1401     HloInstruction* param =
1402         module.entry_computation()->parameter_instruction(i);
1403     if (ShapeContainsToken(param->shape())) {
1404       return InternalError(
1405           "Entry parameter %d is or contains a token shape: %s", i,
1406           ShapeUtil::HumanString(param->shape()));
1407     }
1408   }
1409   return Status::OK();
1410 }
1411 
1412 // Checks if the given two instructions share the same channel id.
CheckSameChannel(const HloInstruction * instr1,const HloInstruction * instr2)1413 Status CheckSameChannel(const HloInstruction* instr1,
1414                         const HloInstruction* instr2) {
1415   if (instr1->channel_id() != instr2->channel_id()) {
1416     return InternalError(
1417         "Expected to have the same channel id, actual channel ids are: %s "
1418         "(%d), %s (%d)",
1419         instr1->ToString(), *instr1->channel_id(), instr2->ToString(),
1420         *instr2->channel_id());
1421   }
1422   return Status::OK();
1423 }
1424 
1425 // Checks if the given two instructions have the same is_host_transfer
1426 // attribute value. Intsructions must be send/recv instructions or their
1427 // 'done' variant.
CheckSameIsHostTransfer(const HloInstruction * instr1,const HloInstruction * instr2)1428 Status CheckSameIsHostTransfer(const HloInstruction* instr1,
1429                                const HloInstruction* instr2) {
1430   const HloSendRecvInstruction* send_recv1 =
1431       DynCast<const HloSendRecvInstruction>(instr1);
1432   const HloSendRecvInstruction* send_recv2 =
1433       DynCast<const HloSendRecvInstruction>(instr2);
1434   TF_RET_CHECK(send_recv1 != nullptr);
1435   TF_RET_CHECK(send_recv2 != nullptr);
1436   if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
1437     return InternalError(
1438         "Expected instructions to have the same is-host-transfer property: "
1439         "%s, "
1440         "%s ",
1441         instr1->ToString(), instr2->ToString());
1442   }
1443   return Status::OK();
1444 }
1445 
VerifySingleUser(const HloInstruction * instruction,HloOpcode expected_user)1446 Status VerifySingleUser(const HloInstruction* instruction,
1447                         HloOpcode expected_user) {
1448   TF_RET_CHECK(instruction->users().size() == 1)
1449       << "The " << HloOpcodeString(instruction->opcode())
1450       << " instruction requires one consumer, found "
1451       << instruction->users().size();
1452 
1453   const HloInstruction* user = instruction->users().front();
1454   TF_RET_CHECK(user->opcode() == expected_user)
1455       << "The consumer of a " << HloOpcodeString(instruction->opcode())
1456       << " instruction needs to be " << HloOpcodeString(expected_user)
1457       << ", found " << HloOpcodeString(user->opcode());
1458   return Status::OK();
1459 }
1460 
VerifySingleOperand(const HloInstruction * instruction,HloOpcode expected_operand)1461 Status VerifySingleOperand(const HloInstruction* instruction,
1462                            HloOpcode expected_operand) {
1463   TF_RET_CHECK(instruction->operands().size() == 1)
1464       << "The " << HloOpcodeString(instruction->opcode())
1465       << " instruction requires one consumer, found "
1466       << instruction->users().size();
1467 
1468   const HloInstruction* operand = instruction->operand(0);
1469   TF_RET_CHECK(operand->opcode() == expected_operand)
1470       << "The operand of a " << HloOpcodeString(instruction->opcode())
1471       << " instruction needs to be " << HloOpcodeString(expected_operand)
1472       << ", found " << HloOpcodeString(operand->opcode());
1473   return Status::OK();
1474 }
1475 
1476 // Checks asynchronous instruction pairs.
VerifyAsynchronousInstructionPairs(const HloModule & module)1477 Status VerifyAsynchronousInstructionPairs(const HloModule& module) {
1478   // CopyStart must have a single CopyDone user.
1479   for (const HloComputation* computation : module.computations()) {
1480     for (const HloInstruction* instruction : computation->instructions()) {
1481       switch (instruction->opcode()) {
1482         case HloOpcode::kCopyStart: {
1483           TF_RETURN_IF_ERROR(
1484               VerifySingleUser(instruction, HloOpcode::kCopyDone));
1485           break;
1486         }
1487         case HloOpcode::kCopyDone: {
1488           TF_RETURN_IF_ERROR(
1489               VerifySingleOperand(instruction, HloOpcode::kCopyStart));
1490           break;
1491         }
1492         case HloOpcode::kCollectivePermuteStart: {
1493           TF_RETURN_IF_ERROR(
1494               VerifySingleUser(instruction, HloOpcode::kCollectivePermuteDone));
1495           break;
1496         }
1497         case HloOpcode::kCollectivePermuteDone: {
1498           TF_RETURN_IF_ERROR(VerifySingleOperand(
1499               instruction, HloOpcode::kCollectivePermuteStart));
1500           break;
1501         }
1502         default:
1503           break;
1504       }
1505     }
1506   }
1507   return Status::OK();
1508 }
1509 
1510 // Checks that AllReduce instructions in the module are either all layout
1511 // constrained or all unconstrained.
VerifyLayoutConstrainedAllReduce(const HloModule & module)1512 Status VerifyLayoutConstrainedAllReduce(const HloModule& module) {
1513   const HloAllReduceInstruction* reference = nullptr;
1514   for (const HloComputation* computation : module.computations()) {
1515     for (const HloInstruction* instruction : computation->instructions()) {
1516       if (instruction->opcode() != HloOpcode::kAllReduce) {
1517         continue;
1518       }
1519       auto all_reduce = DynCast<HloAllReduceInstruction>(instruction);
1520       if (!reference) {
1521         reference = all_reduce;
1522       }
1523       if (reference->constrain_layout() != all_reduce->constrain_layout()) {
1524         return FailedPrecondition(
1525             "HloModule has a mix of layout constrained and unconstrained "
1526             "AllReduce instructions.");
1527       }
1528     }
1529   }
1530   return Status::OK();
1531 }
1532 
1533 // Checks various invariants of channel instructions (send/recv and
1534 // collectives).
VerifyChannels(const HloModule & module)1535 Status VerifyChannels(const HloModule& module) {
1536   absl::flat_hash_map<int64, std::vector<const HloInstruction*>>
1537       channel_instructions;
1538 
1539   // Send/Recv instruction must have a single user: the corresponding
1540   // SendDone/RecvDone. with matching channel.
1541   for (const HloComputation* computation : module.computations()) {
1542     for (const HloInstruction* instruction : computation->instructions()) {
1543       auto channel_instr = DynCast<HloChannelInstruction>(instruction);
1544       if (!channel_instr || !channel_instr->channel_id()) {
1545         continue;
1546       }
1547       channel_instructions[*channel_instr->channel_id()].push_back(instruction);
1548 
1549       switch (instruction->opcode()) {
1550         case HloOpcode::kSend: {
1551           TF_RET_CHECK(instruction->users().size() == 1);
1552           const HloInstruction* send_done = instruction->users().front();
1553           TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
1554           TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
1555           TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
1556           break;
1557         }
1558         case HloOpcode::kRecv: {
1559           TF_RET_CHECK(instruction->users().size() == 1);
1560           const HloInstruction* recv_done = instruction->users().front();
1561           TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
1562           TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
1563           TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
1564           break;
1565         }
1566         case HloOpcode::kSendDone:
1567           TF_RET_CHECK(instruction->operands().size() == 1);
1568           TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
1569           break;
1570         case HloOpcode::kRecvDone:
1571           TF_RET_CHECK(instruction->operands().size() == 1);
1572           TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
1573           break;
1574         default:
1575           break;
1576       }
1577     }
1578   }
1579 
1580   // Iterate over each channel to check invariants.
1581   for (auto& pair : channel_instructions) {
1582     auto& instructions = pair.second;
1583     const HloInstruction* first = instructions[0];
1584     auto sendrecv = DynCast<HloSendRecvInstruction>(first);
1585     if (sendrecv) {
1586       absl::flat_hash_set<HloOpcode> opcodes;
1587       for (const HloInstruction* instr : instructions) {
1588         opcodes.insert(instr->opcode());
1589         auto cast = DynCast<HloSendRecvInstruction>(instr);
1590         TF_RET_CHECK(cast != nullptr)
1591             << "channel " << pair.first
1592             << " is used for different types of channel instructions";
1593       }
1594       if (sendrecv->is_host_transfer()) {
1595         TF_RET_CHECK(instructions.size() == 2)
1596             << "channel " << pair.first
1597             << " is used for multiple host send/recv instructions";
1598       } else {
1599         TF_RET_CHECK(instructions.size() == opcodes.size())
1600             << "channel " << pair.first
1601             << " is used for multiple send/recv instructions";
1602       }
1603     } else {
1604       for (const HloInstruction* instr : instructions) {
1605         TF_RET_CHECK(first->opcode() == instr->opcode())
1606             << "channel " << pair.first
1607             << " is used for different types of channel instructions";
1608       }
1609     }
1610   }
1611 
1612   return Status::OK();
1613 }
1614 
1615 // CHECKs various invariants of a fusion instruction.
CheckFusionInstruction(HloInstruction * fusion)1616 Status CheckFusionInstruction(HloInstruction* fusion) {
1617   // The parent fusion instruction of the fusion computation must be 'fusion'.
1618   HloComputation* fused_computation = fusion->fused_instructions_computation();
1619   if (fusion != fused_computation->FusionInstruction()) {
1620     return InternalError(
1621         "Instruction of fused computation does not match expected "
1622         "instruction "
1623         "%s.",
1624         fusion->ToString());
1625   }
1626 
1627   // Fused root instruction and fused parameters must all be owned by the
1628   // fusion computation.
1629   bool root_owned = false;
1630   const std::vector<HloInstruction*>& fused_parameters =
1631       fusion->fused_parameters();
1632   const HloInstruction* fused_root = fusion->fused_expression_root();
1633   std::vector<bool> parameter_owned(fused_parameters.size(), false);
1634   for (auto* instruction : fused_computation->instructions()) {
1635     if (fused_root == instruction) {
1636       if (root_owned) {
1637         return InternalError("Root appears more than once in %s.",
1638                              fusion->ToString());
1639       }
1640       root_owned = true;
1641     }
1642     for (int i = 0; i < fused_parameters.size(); ++i) {
1643       if (fused_parameters[i] == instruction) {
1644         if (parameter_owned[i]) {
1645           return InternalError("Parameter appears more than once in %s.",
1646                                fusion->ToString());
1647         }
1648         parameter_owned[i] = true;
1649       }
1650     }
1651   }
1652   if (!root_owned) {
1653     return InternalError("Root not found in computation of %s.",
1654                          fusion->ToString());
1655   }
1656   // Make sure all the parameter_owned entries are set
1657   for (int i = 0; i < parameter_owned.size(); i++) {
1658     if (!parameter_owned[i]) {
1659       return InternalError("Parameter %d not found in computation of %s.", i,
1660                            fusion->ToString());
1661     }
1662   }
1663 
1664   // Fused root must have no users.
1665   if (fused_root->user_count() != 0) {
1666     return InternalError("Root of %s may not have users.", fusion->ToString());
1667   }
1668 
1669   // All uses of fused instructions must be in the fusion computation, and
1670   // every non-root instruction must have at least one use.
1671   for (auto* instruction :
1672        fusion->fused_instructions_computation()->instructions()) {
1673     if (instruction != fused_root) {
1674       if (instruction->user_count() == 0) {
1675         return InternalError("Non-root instruction %s in %s must have users.",
1676                              instruction->ToString(), fusion->ToString());
1677       }
1678       for (auto& user : instruction->users()) {
1679         if (fused_computation != user->parent()) {
1680           return InternalError(
1681               "Non-root instruction %s in %s may not have external users.",
1682               instruction->ToString(), fusion->ToString());
1683         }
1684       }
1685     }
1686   }
1687 
1688   // Fused parameter instructions must be numbered contiguously and match up
1689   // (shapes equal) with their respective operand.
1690   CHECK_EQ(fusion->operands().size(), fused_parameters.size());
1691   std::vector<bool> parameter_numbers(fused_parameters.size(), false);
1692   for (auto fused_param : fused_parameters) {
1693     int64 param_no = fused_param->parameter_number();
1694     if (param_no < 0) {
1695       return InternalError("Unexpected negative parameter number %d in %s.",
1696                            param_no, fusion->ToString());
1697     }
1698     if (param_no >= fused_parameters.size()) {
1699       return InternalError(
1700           "Unexpected parameter number %d in %s: higher then number of "
1701           "parameters %lu.",
1702           param_no, fusion->ToString(), fused_parameters.size());
1703     }
1704     if (parameter_numbers[param_no]) {
1705       return InternalError(
1706           "Did not expect parameter number %d more than once in %s.", param_no,
1707           fusion->ToString());
1708     }
1709     parameter_numbers[param_no] = true;
1710   }
1711   // Make sure all the parameter_numbers entries were seen.
1712   for (int i = 0; i < parameter_numbers.size(); i++) {
1713     if (!parameter_numbers[i]) {
1714       return InternalError("Did not see parameter number %d in %s.", i,
1715                            fusion->ToString());
1716     }
1717   }
1718 
1719   TF_RET_CHECK(fusion->called_computations() ==
1720                absl::Span<HloComputation* const>(
1721                    {fusion->fused_instructions_computation()}))
1722       << "Fusion HLO calls computations other than the "
1723          "fused_instructions_computation: "
1724       << fusion->ToString() << " fusion->fused_instructions_computation(): "
1725       << fusion->fused_instructions_computation()->ToString()
1726       << " fusion->called_computations(): "
1727       << ComputationsToString(fusion->called_computations());
1728 
1729   for (const auto& fused : fusion->fused_instructions()) {
1730     TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation())
1731         << "Fused HLO was missing a parent: " << fused->ToString()
1732         << " parent: " << fused->parent()
1733         << " computation: " << fusion->parent();
1734   }
1735 
1736   // TODO(b/65423525): We'd like to check that all operands are distinct.
1737   // This is currently disabled due to the invariant being violated by
1738   // multi-output fusion.
1739   return Status::OK();
1740 }
1741 
1742 // Checks that the operand shapes are compatible to the output shape, i.e.,
1743 // that there are no implicit broadcasts.
CheckElementwiseInstruction(HloInstruction * instruction)1744 Status CheckElementwiseInstruction(HloInstruction* instruction) {
1745   const Shape& out_shape = instruction->shape();
1746   for (HloInstruction* operand : instruction->operands()) {
1747     const Shape& operand_shape = operand->shape();
1748     if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
1749       return FailedPrecondition(
1750           "Implicit broadcast is not allowed in HLO."
1751           "Found different shapes for instruction %s.\n"
1752           "output: %s\noperand: %s\n",
1753           HloOpcodeString(instruction->opcode()),
1754           ShapeUtil::HumanString(out_shape),
1755           ShapeUtil::HumanString(operand_shape));
1756     }
1757   }
1758   if (auto* comparison = DynCast<HloCompareInstruction>(instruction)) {
1759     const Shape& operand_shape = comparison->operand(1)->shape();
1760     PrimitiveType operand_element_type = operand_shape.element_type();
1761     Comparison::Type default_comparison_type =
1762         Comparison::DefaultComparisonType(operand_element_type);
1763     if (primitive_util::IsFloatingPointType(operand_element_type)) {
1764       if (comparison->type() != Comparison::Type::kFloat &&
1765           comparison->type() != Comparison::Type::kFloatTotalOrder) {
1766         return FailedPrecondition(
1767             "Expected comparison type %s or %s.\n"
1768             "actual: %s\noperand: %s\n",
1769             ComparisonTypeToString(Comparison::Type::kFloat),
1770             ComparisonTypeToString(Comparison::Type::kFloatTotalOrder),
1771             ComparisonTypeToString(comparison->type()),
1772             ShapeUtil::HumanString(operand_shape));
1773       }
1774     } else if (comparison->type() != default_comparison_type) {
1775       return FailedPrecondition(
1776           "Expected comparison type %s.\n"
1777           "actual: %s\noperand: %s\n",
1778           ComparisonTypeToString(default_comparison_type),
1779           ComparisonTypeToString(comparison->type()),
1780           ShapeUtil::HumanString(operand_shape));
1781     }
1782   }
1783   return Status::OK();
1784 }
1785 
1786 // Visitor which verifies various fields on the HLO instruction. This class does
1787 // not check result shape as that is checked in the ShapeVerifier.
1788 class InstructionVerifier : public DfsHloVisitorWithDefault {
1789  public:
InstructionVerifier(std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)1790   explicit InstructionVerifier(std::function<bool(const HloInstruction*)>
1791                                    instruction_can_change_layout_func)
1792       : instruction_can_change_layout_func_(
1793             instruction_can_change_layout_func) {}
1794 
DefaultAction(HloInstruction *)1795   Status DefaultAction(HloInstruction*) override { return Status::OK(); }
1796 
HandleFusion(HloInstruction * fusion)1797   Status HandleFusion(HloInstruction* fusion) override {
1798     return CheckFusionInstruction(fusion);
1799   }
1800 
HandleBroadcast(HloInstruction * broadcast)1801   Status HandleBroadcast(HloInstruction* broadcast) override {
1802     // If you see this failure then someone has confused the difference
1803     // between the HLO broadcast op, and the UserComputation broadcast
1804     // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
1805     // or ComputationLowerer::Visit()
1806     TF_RET_CHECK(broadcast->dimensions().size() ==
1807                  broadcast->operand(0)->shape().rank())
1808         << "Broadcast HLO (" << broadcast->ToShortString()
1809         << ") has invalid number of dimensions: "
1810         << broadcast->dimensions().size()
1811         << " != " << broadcast->operand(0)->shape().rank();
1812     return Status::OK();
1813   }
1814 
HandleWhile(HloInstruction * xla_while)1815   Status HandleWhile(HloInstruction* xla_while) override {
1816     auto* while_cond = xla_while->while_condition();
1817     auto* while_body = xla_while->while_body();
1818     if (while_cond->num_parameters() != 1) {
1819       return FailedPrecondition(
1820           "While condition must have exactly 1 parameter; had %d : %s",
1821           while_cond->num_parameters(), while_cond->ToString());
1822     }
1823     if (while_body->num_parameters() != 1) {
1824       return FailedPrecondition(
1825           "While body must have exactly 1 parameter; had %d : %s",
1826           while_body->num_parameters(), while_body->ToString());
1827     }
1828     if (xla_while->operand_count() != 1) {
1829       return FailedPrecondition(
1830           "While loop must have exactly one operand; had %d : %s",
1831           xla_while->operand_count(), xla_while->ToString());
1832     }
1833     return Status::OK();
1834   }
1835 
HandleConditional(HloInstruction * conditional)1836   Status HandleConditional(HloInstruction* conditional) override {
1837     for (int b = 0; b < conditional->branch_count(); ++b) {
1838       if (conditional->branch_computation(b)->num_parameters() != 1) {
1839         return FailedPrecondition(
1840             "Branch computation %s of %s must have 1 parameter instead of %d",
1841             conditional->branch_computation(b)->name(), conditional->ToString(),
1842             conditional->branch_computation(b)->num_parameters());
1843       }
1844     }
1845     return Status::OK();
1846   }
1847 
HandleElementwiseUnary(HloInstruction * instruction)1848   Status HandleElementwiseUnary(HloInstruction* instruction) override {
1849     return CheckElementwiseInstruction(instruction);
1850   }
1851 
HandleElementwiseBinary(HloInstruction * instruction)1852   Status HandleElementwiseBinary(HloInstruction* instruction) override {
1853     return CheckElementwiseInstruction(instruction);
1854   }
1855 
HandleGetTupleElement(HloInstruction * gte)1856   Status HandleGetTupleElement(HloInstruction* gte) override {
1857     TF_RET_CHECK(gte->operand(0)->shape().IsTuple());
1858     return Status::OK();
1859   }
1860 
HandleTranspose(HloInstruction * transpose)1861   Status HandleTranspose(HloInstruction* transpose) override {
1862     const Shape& shape = transpose->shape();
1863     const HloInstruction* operand = transpose->operand(0);
1864     TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size());
1865     TF_RET_CHECK(shape.dimensions().size() ==
1866                  transpose->operand(0)->shape().dimensions().size());
1867     TF_RET_CHECK(std::equal(
1868         shape.dimensions().begin(), shape.dimensions().end(),
1869         Permute(operand->shape().dimensions(), transpose->dimensions())
1870             .begin()))
1871         << "shape: " << shape << ", operand->shape(): " << shape
1872         << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ")
1873         << "}";
1874     return Status::OK();
1875   }
1876 
HandleAllReduce(HloInstruction * crs)1877   Status HandleAllReduce(HloInstruction* crs) override {
1878     if (crs->channel_id().has_value()) {
1879       TF_RET_CHECK(crs->channel_id().value() > 0)
1880           << "All reduce channel id must be greater than 0 for "
1881           << crs->ToShortString();
1882     }
1883     return Status::OK();
1884   }
1885 
Preprocess(HloInstruction * instruction)1886   Status Preprocess(HloInstruction* instruction) override {
1887     auto previous = instructions_by_name_.find(instruction->name());
1888     TF_RET_CHECK(previous == instructions_by_name_.end())
1889         << "HLO has name that is not unique within module:\n"
1890         << instruction->ToString()
1891         << " in computation: " << instruction->parent()->name()
1892         << "\nPrevious HLO with same name:\n"
1893         << previous->second->ToString()
1894         << " in computation: " << previous->second->parent()->name();
1895     instructions_by_name_[instruction->name()] = instruction;
1896     return Status::OK();
1897   }
1898 
Postprocess(HloInstruction * instruction)1899   Status Postprocess(HloInstruction* instruction) override {
1900     if (instruction_can_change_layout_func_ &&
1901         LayoutUtil::IsDenseArray(instruction->shape()) &&
1902         !instruction_can_change_layout_func_(instruction)) {
1903       const Shape& result_shape = instruction->shape();
1904       const Layout& result_layout = result_shape.layout();
1905       for (HloInstruction* operand : instruction->operands()) {
1906         const Shape& operand_shape = operand->shape();
1907         if (LayoutUtil::IsDenseArray(operand_shape) &&
1908             operand_shape.rank() == result_shape.rank()) {
1909           const Layout& operand_layout = operand_shape.layout();
1910           TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
1911               << "Instruction shouldn't change layouts "
1912               << instruction->ToString() << " From " << result_shape << " To "
1913               << operand_shape;
1914         }
1915       }
1916     }
1917 
1918     return Status::OK();
1919   }
1920 
1921  private:
1922   absl::flat_hash_map<string, const HloInstruction*> instructions_by_name_;
1923   // Determines whether an instruction can change layouts.
1924   std::function<bool(const HloInstruction*)>
1925       instruction_can_change_layout_func_;
1926 };
1927 
1928 }  // namespace
1929 
Run(HloModule * module)1930 StatusOr<bool> HloVerifier::Run(HloModule* module) {
1931   TF_RET_CHECK(!module->name().empty());
1932 
1933   if (module->entry_computation()->IsFusionComputation()) {
1934     return InvalidArgument(
1935         "Module entry computation cannot be a fusion computation");
1936   }
1937 
1938   TF_RETURN_IF_ERROR(VerifyHloStructure(module));
1939   TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module));
1940   TF_RETURN_IF_ERROR(VerifyChannels(*module));
1941 
1942   std::unique_ptr<ShapeVerifier> shape_verifier =
1943       target_metadata_->GetVerifier();
1944   InstructionVerifier instruction_verifier(instruction_can_change_layout_func_);
1945   for (auto* computation : module->computations()) {
1946     TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
1947     TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
1948   }
1949 
1950   TF_RETURN_IF_ERROR(shape_verifier->VerifyEntryComputationLayout(*module));
1951   TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
1952 
1953   // If the module has a schedule, it must be valid.
1954   if (module->has_schedule()) {
1955     TF_RETURN_IF_ERROR(module->schedule().Verify());
1956   }
1957 
1958   TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(
1959       *module, [this](const Shape& shape) -> int64 {
1960         if (target_metadata_->IsLayoutSensitive()) {
1961           return target_metadata_->ShapeSize(shape);
1962         } else {
1963           return 0;
1964         }
1965       }));
1966 
1967   TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module));
1968   TF_RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module));
1969 
1970   return false;
1971 }
1972 
1973 }  // namespace xla
1974