1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <set>
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/strings/str_join.h"
20 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
23 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
24 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 
29 namespace xla {
30 
VerifyNotSparse(const Shape & shape)31 Status VerifyNotSparse(const Shape& shape) {
32   return ShapeUtil::ForEachSubshapeWithStatus(
33       shape, [](const Shape& subshape, const ShapeIndex&) -> Status {
34         if (LayoutUtil::IsSparseArray(subshape)) {
35           return InternalError("Sparse arrays are not yet fully supported: %s",
36                                ShapeUtil::HumanStringWithLayout(subshape));
37         }
38         return Status::OK();
39       });
40 }
41 
IsCallerInstruction(HloInstruction * hlo)42 bool IsCallerInstruction(HloInstruction* hlo) {
43   switch (hlo->opcode()) {
44     case HloOpcode::kCall:
45     case HloOpcode::kConditional:
46     case HloOpcode::kWhile:
47     case HloOpcode::kAllReduce:
48     case HloOpcode::kMap:
49     case HloOpcode::kReduce:
50     case HloOpcode::kReduceWindow:
51     case HloOpcode::kScatter:
52     case HloOpcode::kSelectAndScatter:
53     case HloOpcode::kSort:
54     case HloOpcode::kFusion:
55       return true;
56     default:
57       return false;
58   }
59 }
60 
61 namespace {
62 
CheckOperandCount(const HloInstruction * hlo,int expected)63 Status CheckOperandCount(const HloInstruction* hlo, int expected) {
64   if (hlo->operand_count() != expected) {
65     return InternalError("Expected %d operands for %s instruction: %s",
66                          expected, HloOpcodeString(hlo->opcode()),
67                          hlo->ToString());
68   }
69   return Status::OK();
70 }
71 
CheckParameterCount(const HloInstruction * calling_instruction,const HloComputation * computation,int expected)72 Status CheckParameterCount(const HloInstruction* calling_instruction,
73                            const HloComputation* computation, int expected) {
74   if (computation->num_parameters() != expected) {
75     return InternalError(
76         "Expected computation %s called from %s to have %d parameters, has %d",
77         computation->name(), calling_instruction->name(), expected,
78         computation->num_parameters());
79   }
80   return Status::OK();
81 }
82 
83 }  // namespace
84 
Preprocess(HloInstruction * hlo)85 Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
86   if (!hlo->called_computations().empty() && !IsCallerInstruction(hlo)) {
87     return InternalError(
88         "Called computations specified for non-caller instruction  %s",
89         hlo->ToString());
90   }
91   TF_RETURN_IF_ERROR(VerifyNotSparse(hlo->shape()));
92 
93   absl::optional<int> arity = HloOpcodeArity(hlo->opcode());
94   if (arity) {
95     TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity));
96   }
97   return Status::OK();
98 }
99 
HandleElementwiseUnary(HloInstruction * hlo)100 Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
101   return CheckUnaryShape(hlo);
102 }
103 
HandleElementwiseBinary(HloInstruction * hlo)104 Status ShapeVerifier::HandleElementwiseBinary(HloInstruction* hlo) {
105   return CheckBinaryShape(hlo);
106 }
107 
HandleClamp(HloInstruction * clamp)108 Status ShapeVerifier::HandleClamp(HloInstruction* clamp) {
109   return CheckTernaryShape(clamp);
110 }
111 
HandleSelect(HloInstruction * select)112 Status ShapeVerifier::HandleSelect(HloInstruction* select) {
113   return CheckTernaryShape(select);
114 }
115 
HandleTupleSelect(HloInstruction * tuple_select)116 Status ShapeVerifier::HandleTupleSelect(HloInstruction* tuple_select) {
117   return CheckTernaryShape(tuple_select);
118 }
119 
HandleConcatenate(HloInstruction * concatenate)120 Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) {
121   std::vector<const Shape*> operand_shapes;
122   for (const HloInstruction* operand : concatenate->operands()) {
123     operand_shapes.push_back(&operand->shape());
124   }
125   return CheckShape(concatenate,
126                     ShapeInference::InferConcatOpShape(
127                         operand_shapes, concatenate->concatenate_dimension()));
128 }
129 
HandleConvert(HloInstruction * convert)130 Status ShapeVerifier::HandleConvert(HloInstruction* convert) {
131   return CheckShape(convert, ShapeInference::InferConvertShape(
132                                  convert->operand(0)->shape(),
133                                  convert->shape().element_type()));
134 }
135 
HandleBitcastConvert(HloInstruction * convert)136 Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) {
137   return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
138                                  convert->operand(0)->shape(),
139                                  convert->shape().element_type()));
140 }
141 
HandleCopy(HloInstruction * copy)142 Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
143   return CheckUnaryShape(copy);
144 }
145 
HandleDot(HloInstruction * dot)146 Status ShapeVerifier::HandleDot(HloInstruction* dot) {
147   TF_ASSIGN_OR_RETURN(const Shape expected,
148                       ShapeInference::InferDotOpShape(
149                           dot->operand(0)->shape(), dot->operand(1)->shape(),
150                           dot->dot_dimension_numbers()));
151   return CheckShape(dot, expected);
152 }
153 
HandleConvolution(HloInstruction * convolution)154 Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
155   TF_ASSIGN_OR_RETURN(
156       const Shape expected,
157       ShapeInference::InferConvolveShape(
158           convolution->operand(0)->shape(), convolution->operand(1)->shape(),
159           convolution->feature_group_count(), convolution->batch_group_count(),
160           convolution->window(), convolution->convolution_dimension_numbers()));
161   return CheckShape(convolution, expected);
162 }
163 
HandleFft(HloInstruction * fft)164 Status ShapeVerifier::HandleFft(HloInstruction* fft) {
165   TF_ASSIGN_OR_RETURN(
166       const Shape expected,
167       ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
168                                     fft->fft_length()));
169   return CheckShape(fft, expected);
170 }
171 
HandleTriangularSolve(HloInstruction * hlo)172 Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) {
173   TF_ASSIGN_OR_RETURN(const Shape expected,
174                       ShapeInference::InferTriangularSolveShape(
175                           hlo->operand(0)->shape(), hlo->operand(1)->shape(),
176                           hlo->triangular_solve_options()));
177   return CheckShape(hlo, expected);
178 }
179 
HandleCholesky(HloInstruction * hlo)180 Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
181   TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1));
182   TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferCholeskyShape(
183                                                 hlo->operand(0)->shape()));
184   return CheckShape(hlo, expected);
185 }
186 
HandleAllReduce(HloInstruction * crs)187 Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) {
188   std::vector<const Shape*> operand_shapes;
189   for (const HloInstruction* operand : crs->operands()) {
190     operand_shapes.push_back(&operand->shape());
191   }
192   return CheckShape(crs, ShapeInference::InferAllReduceShape(operand_shapes));
193 }
194 
HandleAllToAll(HloInstruction * hlo)195 Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
196   std::vector<const Shape*> operand_shapes;
197   for (const HloInstruction* operand : hlo->operands()) {
198     operand_shapes.push_back(&operand->shape());
199   }
200   return CheckShape(hlo,
201                     ShapeInference::InferAllToAllTupleShape(operand_shapes));
202 }
203 
HandleReplicaId(HloInstruction * hlo)204 Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) {
205   return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
206 }
207 
HandleCollectivePermute(HloInstruction * hlo)208 Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
209   return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape(
210                              hlo->operand(0)->shape()));
211 }
212 
HandleReducePrecision(HloInstruction * reduce_precision)213 Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
214   return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
215                                           reduce_precision->operand(0)->shape(),
216                                           reduce_precision->exponent_bits(),
217                                           reduce_precision->mantissa_bits()));
218 }
219 
CheckIsTokenOperand(const HloInstruction * instruction,int64 operand_no)220 Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction,
221                                           int64 operand_no) {
222   const HloInstruction* token = instruction->operand(operand_no);
223   if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
224     return InternalError(
225         "Expected operand %d to be token-shaped, actual shape is "
226         "%s:\n%s",
227         operand_no, StringifyShape(token->shape()), instruction->ToString());
228   }
229   return Status::OK();
230 }
231 
CheckOperandAndParameter(const HloInstruction * instruction,int64 operand_number,const HloComputation * computation,int64 parameter_number)232 Status ShapeVerifier::CheckOperandAndParameter(
233     const HloInstruction* instruction, int64 operand_number,
234     const HloComputation* computation, int64 parameter_number) {
235   const HloInstruction* operand = instruction->operand(operand_number);
236   const HloInstruction* parameter =
237       computation->parameter_instruction(parameter_number);
238   if (!ShapesSame(operand->shape(), parameter->shape())) {
239     return InternalError("Operand %s shape does not match parameter's %s in %s",
240                          operand->ToString(), parameter->ToString(),
241                          instruction->ToString());
242   }
243   return Status::OK();
244 }
245 
HandleInfeed(HloInstruction * instruction)246 Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
247   HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
248   TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
249 
250   // The output of infeed is a tuple containing the data value and a token.
251   return CheckShape(infeed,
252                     ShapeUtil::MakeTupleShape(
253                         {infeed->infeed_shape(), ShapeUtil::MakeTokenShape()}));
254 }
255 
HandleOutfeed(HloInstruction * instruction)256 Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
257   HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction);
258   TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1));
259 
260   // Outfeed has a separate shape field for the value which is outfed to the
261   // host. The shape of the instruction itself is always a token.
262   if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) {
263     return InternalError(
264         "Expected outfeed shape to be equal to operand's shape %s, "
265         "actual shape is %s:\n%s",
266         StringifyShape(outfeed->operand(0)->shape()),
267         StringifyShape(outfeed->outfeed_shape()), outfeed->ToString());
268   }
269   return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
270 }
271 
HasCompatibleElementTypes(const Shape & shape_0,const Shape & shape_1,const Shape & result_shape)272 bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
273                                               const Shape& shape_1,
274                                               const Shape& result_shape) {
275   return ShapeUtil::SameElementType(shape_0, shape_1) &&
276          (ShapeUtil::SameElementType(shape_0, result_shape) ||
277           (allow_mixed_precision_ &&
278            ShapeUtil::SameElementTypeIgnoringFpPrecision(shape_0,
279                                                          result_shape)));
280 }
281 
HandleRng(HloInstruction * instruction)282 Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
283   TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2));
284 
285   const Shape& shape_0 = instruction->operand(0)->shape();
286   const Shape& shape_1 = instruction->operand(1)->shape();
287   if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) {
288     return InternalError(
289         "Expected scalar types for the two operands of Rng instruction: %s",
290         instruction->ToString());
291   }
292 
293   if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) {
294     return InternalError(
295         "Expected compatible element types for the result and the two operands"
296         " of Rng instruction: %s",
297         instruction->ToString());
298   }
299 
300   PrimitiveType element_type = shape_0.element_type();
301   switch (instruction->random_distribution()) {
302     case RNG_UNIFORM:
303       if (!primitive_util::IsFloatingPointType(element_type) &&
304           !primitive_util::IsIntegralType(element_type) &&
305           element_type != PRED) {
306         return InternalError(
307             "Element type not supported."
308             " Expected element to be of floating point type, integral type or"
309             " predicate type for RngUniform: %s",
310             instruction->ToString());
311       }
312       break;
313 
314     case RNG_NORMAL:
315       if (!primitive_util::IsFloatingPointType(element_type)) {
316         return InternalError(
317             "Element type not supported."
318             " Expected element to be FloatingPointType for RngNormal: %s",
319             instruction->ToString());
320       }
321       break;
322     default:
323       return InternalError(
324           "Invalid Rng distribution %s",
325           RandomDistribution_Name(instruction->random_distribution()));
326   }
327 
328   return Status::OK();
329 }
330 
HandleReverse(HloInstruction * reverse)331 Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
332   return CheckShape(
333       reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(),
334                                                  reverse->dimensions()));
335 }
336 
HandleSort(HloInstruction * sort)337 Status ShapeVerifier::HandleSort(HloInstruction* sort) {
338   if (sort->operand_count() < 1) {
339     return InternalError("Expected at least 1 operand for %s instruction: %s",
340                          HloOpcodeString(sort->opcode()), sort->ToString());
341   }
342   HloComputation* compare = sort->to_apply();
343 
344   // Check that the 'compare' computation returns a PRED.
345   Shape compare_shape = compare->root_instruction()->shape();
346   if (!ShapesSame(compare_shape, ShapeUtil::MakeShape(PRED, {}))) {
347     return InternalError(
348         "The Sort compare computation shape does not lead to a scalar "
349         "predicate shape: %s",
350         StringifyShape(compare_shape));
351   }
352 
353   // Check that the number of parameters of the 'compare' computation is
354   // correct.
355   TF_RETURN_IF_ERROR(
356       CheckParameterCount(sort, compare, sort->operand_count() * 2));
357 
358   // Verify that the operands of the compare computation have the correct scalar
359   // shapes.
360   for (int64 parameter_idx = 0; parameter_idx < compare->num_parameters();
361        ++parameter_idx) {
362     int64 operand_idx = parameter_idx / 2;
363     Shape expected_scalar_shape = ShapeUtil::MakeShape(
364         sort->operand(operand_idx)->shape().element_type(), {});
365     Shape actual_parameter_shape =
366         compare->parameter_instruction(parameter_idx)->shape();
367     if (!ShapeUtil::CompatibleIgnoringFpPrecision(expected_scalar_shape,
368                                                   actual_parameter_shape)) {
369       return InternalError(
370           "Expected the %lld-th parameter of the compare computation of sort "
371           "to have shape %s, but got %s",
372           parameter_idx, StringifyShape(expected_scalar_shape),
373           StringifyShape(actual_parameter_shape));
374     }
375   }
376 
377   // Verify that all operand shapes have the same dimensions.
378   for (int64 operand = 1; operand < sort->operand_count(); ++operand) {
379     if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(),
380                                    sort->operand(operand)->shape())) {
381       return InternalError(
382           "Expected sort to have to have the same dimensions for all operands. "
383           "First operand shape is: %s\n, shape (operand index %lld) is: %s",
384           StringifyShape(sort->operand(0)->shape()), operand,
385           StringifyShape(sort->operand(operand)->shape()));
386     }
387   }
388   return CheckVariadicShape(sort);
389 }
390 
HandleConstant(HloInstruction * constant)391 Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
392   if (!Cast<HloConstantInstruction>(constant)->HasLiteral()) {
393     return InternalError("Constant is required to have a valid literal: %s",
394                          constant->ToString());
395   }
396   return CheckShape(constant, constant->literal().shape());
397 }
398 
HandleIota(HloInstruction * instruction)399 Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
400   auto* iota = Cast<HloIotaInstruction>(instruction);
401   if (!iota->shape().IsArray()) {
402     return InternalError("Iota does not support non-array result.");
403   }
404   const int64 rank = iota->shape().rank();
405   if (rank == 0) {
406     return InternalError("Iota does not support scalars.");
407   }
408   int64 iota_dimension = iota->iota_dimension();
409   if (iota_dimension >= rank) {
410     return InternalError(
411         "The iota dimension cannot go beyond the operation rank.");
412   }
413   return Status::OK();
414 }
415 
HandleGetTupleElement(HloInstruction * get_tuple_element)416 Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
417   return CheckShape(get_tuple_element,
418                     ShapeInference::InferGetTupleElementShape(
419                         get_tuple_element->operand(0)->shape(),
420                         get_tuple_element->tuple_index()));
421 }
422 
423 namespace {
SameElementTypesForOperandsAndToApplyParameters(const HloInstruction & instruction,int64 num_operands_to_check)424 Status SameElementTypesForOperandsAndToApplyParameters(
425     const HloInstruction& instruction, int64 num_operands_to_check) {
426   const ProgramShape& to_apply = instruction.to_apply()->ComputeProgramShape();
427   for (int i = 0; i < num_operands_to_check; ++i) {
428     const Shape& parameter_shape = to_apply.parameters(i);
429     const Shape& operand_shape = instruction.operands()[i]->shape();
430     if (!ShapeUtil::SameElementType(parameter_shape, operand_shape)) {
431       return InvalidArgument(
432           "Shape mismatch between to_apply computation"
433           " parameter and operand %d in %s.",
434           i, instruction.ToString().c_str());
435     }
436   }
437   return Status::OK();
438 }
439 }  // namespace
440 
HandleReduce(HloInstruction * reduce)441 Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
442   if (reduce->operand_count() % 2 != 0) {
443     return InternalError(
444         "Expected an even number of operands for %s instruction: %s",
445         HloOpcodeString(reduce->opcode()), reduce->ToString());
446   }
447 
448   std::vector<const Shape*> operand_shapes;
449   for (const HloInstruction* operand : reduce->operands()) {
450     operand_shapes.push_back(&operand->shape());
451   }
452   TF_RETURN_IF_ERROR(
453       CheckShape(reduce, ShapeInference::InferReduceShape(
454                              operand_shapes, reduce->dimensions(),
455                              reduce->to_apply()->ComputeProgramShape())));
456 
457   return allow_mixed_precision_
458              ? Status::OK()
459              : SameElementTypesForOperandsAndToApplyParameters(
460                    *reduce, reduce->operands().size() - 1);
461 }
462 
HandleBitcast(HloInstruction * bitcast)463 Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
464   // Bitcasts are not allowed to change the element type.
465   if (bitcast->operand(0)->shape().element_type() !=
466       bitcast->shape().element_type()) {
467     return InternalError(
468         "Bitcast can not change the element type from %s to %s",
469         PrimitiveType_Name(bitcast->operand(0)->shape().element_type()),
470         PrimitiveType_Name(bitcast->shape().element_type()));
471   }
472   return Status::OK();
473 }
474 
// Verifies a broadcast explicitly. The operand and result must share an
// element type, and dimensions() must have exactly one entry per operand
// dimension. Entry i names an in-range result dimension whose size equals the
// size of operand dimension i.
Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
  // HLO broadcast has no exact analog at the proto level so there is no
  // ShapeInference method. Check the output shape explicitly.
  const Shape& operand_shape = broadcast->operand(0)->shape();
  // Check for mixed precision.
  TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape));
  TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size());
  for (int64 operand_dimension = 0; operand_dimension < operand_shape.rank();
       ++operand_dimension) {
    // dimensions()[operand_dimension] is the result dimension that this
    // operand dimension maps to; it must be in range and size-preserving.
    int64 output_dimension = broadcast->dimensions()[operand_dimension];
    TF_RET_CHECK((output_dimension < broadcast->shape().rank()) &&
                 output_dimension >= 0 &&
                 (broadcast->shape().dimensions(output_dimension) ==
                  operand_shape.dimensions(operand_dimension)))
        << broadcast->ToString() << " operand shape " << operand_shape;
  }
  return Status::OK();
}
493 
// Verifies a reshape. The result must keep the operand's element type and its
// total element count. No other dimension constraint is checked here.
Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
  // Check for mixed precision.
  const Shape& operand_shape = reshape->operand(0)->shape();
  TF_RET_CHECK(SameElementType(reshape->shape(), operand_shape));
  // A reshape only reinterprets dimensions; the number of elements must match.
  TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
               ShapeUtil::ElementsIn(operand_shape));
  return Status::OK();
}
502 
HandleTranspose(HloInstruction * transpose)503 Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
504   return CheckShape(
505       transpose, ShapeInference::InferTransposeShape(
506                      transpose->operand(0)->shape(), transpose->dimensions()));
507 }
508 
HandleParameter(HloInstruction * hlo)509 Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
510   return Status::OK();
511 }
512 
HandleFusion(HloInstruction * fusion)513 Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
514   auto& fused_parameters = fusion->fused_parameters();
515   if (fused_parameters.size() != fusion->operand_count()) {
516     return InternalError(
517         "Fused parameter count (%d) does not match the number of operands (%d)"
518         " passed to the fusion instruction in: %s.",
519         fused_parameters.size(), fusion->operand_count(),
520         fusion->ToString().c_str());
521   }
522   for (HloInstruction* fused_param : fused_parameters) {
523     int64 param_no = fused_param->parameter_number();
524     if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) {
525       return InternalError(
526           "Shape mismatch between parameter number %d and its operand in "
527           "%s.",
528           param_no, fusion->ToString().c_str());
529     }
530   }
531   return Status::OK();
532 }
533 
HandleCall(HloInstruction * call)534 Status ShapeVerifier::HandleCall(HloInstruction* call) {
535   TF_RETURN_IF_ERROR(
536       CheckParameterCount(call, call->to_apply(), call->operand_count()));
537   for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
538     TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
539   }
540   // The shape of kCall should match the shape of the computation it calls.
541   return CheckShape(call, call->to_apply()->root_instruction()->shape());
542 }
543 
// Verifies a custom call. Only layout-constrained custom calls are checked;
// unconstrained ones are accepted as-is. For a constrained call:
//  * the result shape must have a layout;
//  * operand_shapes_with_layout() must have exactly one entry per operand;
//  * each entry must have a layout and be ShapeUtil::Compatible with the
//    actual operand shape.
Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
  // NOTE(review): DynCast presumably yields nullptr when `instruction` is not
  // a custom call; that case is reported as an internal error below.
  const HloCustomCallInstruction* custom_call =
      DynCast<const HloCustomCallInstruction>(instruction);
  TF_RET_CHECK(custom_call != nullptr);
  if (custom_call->layout_constrained()) {
    // If the layout is constrained, verify all the respective shapes have
    // layouts and that the constrained operand shapes match the shapes of the
    // operands.
    TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape()));
    TF_RET_CHECK(custom_call->operand_count() ==
                 custom_call->operand_shapes_with_layout().size());
    for (int64 i = 0; i < custom_call->operand_count(); ++i) {
      const Shape& operand_shape_with_layout =
          custom_call->operand_shapes_with_layout()[i];
      TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(),
                                         operand_shape_with_layout))
          << custom_call->operand(i)->shape().ToString() << " operand "
          << operand_shape_with_layout.ToString();
      TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
    }
  }
  return Status::OK();
}
567 
HandleSlice(HloInstruction * slice)568 Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
569   return CheckShape(slice,
570                     ShapeInference::InferSliceShape(
571                         slice->operand(0)->shape(), slice->slice_starts(),
572                         slice->slice_limits(), slice->slice_strides()));
573 }
574 
HandleDynamicSlice(HloInstruction * dynamic_slice)575 Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
576   return CheckShape(
577       dynamic_slice,
578       ShapeInference::InferDynamicSliceShape(
579           dynamic_slice->operand(0)->shape(),
580           Cast<HloDynamicSliceInstruction>(dynamic_slice)->index_shapes(),
581           dynamic_slice->dynamic_slice_sizes()));
582 }
583 
HandleDynamicUpdateSlice(HloInstruction * dynamic_update_slice)584 Status ShapeVerifier::HandleDynamicUpdateSlice(
585     HloInstruction* dynamic_update_slice) {
586   return CheckShape(
587       dynamic_update_slice,
588       ShapeInference::InferDynamicUpdateSliceShape(
589           dynamic_update_slice->operand(0)->shape(),
590           dynamic_update_slice->operand(1)->shape(),
591           Cast<HloDynamicUpdateSliceInstruction>(dynamic_update_slice)
592               ->index_shapes()));
593 }
594 
HandleTuple(HloInstruction * tuple)595 Status ShapeVerifier::HandleTuple(HloInstruction* tuple) {
596   return CheckVariadicShape(tuple);
597 }
598 
HandleMap(HloInstruction * map)599 Status ShapeVerifier::HandleMap(HloInstruction* map) {
600   std::vector<const Shape*> operand_shapes;
601   int64 max_operand_rank = 0;
602   for (const HloInstruction* operand : map->operands()) {
603     operand_shapes.push_back(&operand->shape());
604     max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
605   }
606   // TODO(b/65689298) Remove code below once Map is generalized to accept
607   // arbitrary map dimensions.
608   std::vector<int64> map_dims(max_operand_rank);
609   std::iota(map_dims.begin(), map_dims.end(), 0);
610 
611   TF_RETURN_IF_ERROR(CheckShape(
612       map,
613       ShapeInference::InferMapShape(
614           operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims)));
615 
616   return allow_mixed_precision_
617              ? Status::OK()
618              : SameElementTypesForOperandsAndToApplyParameters(
619                    *map, map->operands().size());
620 }
621 
HandleReduceWindow(HloInstruction * reduce_window)622 Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
623   TF_RETURN_IF_ERROR(CheckShape(
624       reduce_window,
625       ShapeInference::InferReduceWindowShape(
626           reduce_window->operand(0)->shape(),
627           reduce_window->operand(1)->shape(), reduce_window->window(),
628           reduce_window->to_apply()->ComputeProgramShape())));
629 
630   return allow_mixed_precision_
631              ? Status::OK()
632              : SameElementTypesForOperandsAndToApplyParameters(*reduce_window,
633                                                                1);
634 }
635 
HandleSelectAndScatter(HloInstruction * instruction)636 Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
637   return CheckShape(
638       instruction,
639       ShapeInference::InferSelectAndScatterShape(
640           instruction->operand(0)->shape(),
641           instruction->select()->ComputeProgramShape(), instruction->window(),
642           instruction->operand(1)->shape(), instruction->operand(2)->shape(),
643           instruction->scatter()->ComputeProgramShape()));
644 }
645 
HandleWhile(HloInstruction * xla_while)646 Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
647   TF_RETURN_IF_ERROR(
648       CheckParameterCount(xla_while, xla_while->while_body(), 1));
649   TF_RETURN_IF_ERROR(
650       CheckParameterCount(xla_while, xla_while->while_condition(), 1));
651   TF_RETURN_IF_ERROR(
652       CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
653   TF_RETURN_IF_ERROR(
654       CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
655   const Shape& conditional_shape =
656       xla_while->while_condition()->root_instruction()->shape();
657   if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) {
658     return InternalError(
659         "Conditional computation shape does not lead to a scalar predicate "
660         "shape: %s",
661         StringifyShape(conditional_shape));
662   }
663   // The shape of kWhile should match the shape of the body computation it
664   // calls.
665   return CheckShape(xla_while,
666                     xla_while->while_body()->root_instruction()->shape());
667 }
668 
HandleConditional(HloInstruction * conditional)669 Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
670   const int num_branches = conditional->branch_count();
671   if (conditional->operand(0)->shape().element_type() == PRED) {
672     TF_RET_CHECK(num_branches == 2);
673   } else {
674     TF_RET_CHECK(num_branches >= 1);
675   }
676   TF_RETURN_IF_ERROR(CheckOperandCount(conditional, num_branches + 1));
677   for (int j = 0; j < num_branches; ++j) {
678     TF_RETURN_IF_ERROR(CheckParameterCount(
679         conditional, conditional->branch_computation(j), 1));
680     TF_RETURN_IF_ERROR(CheckOperandAndParameter(
681         conditional, j + 1, conditional->branch_computation(j), 0));
682     TF_RETURN_IF_ERROR(CheckShape(
683         conditional,
684         conditional->branch_computation(j)->root_instruction()->shape()));
685   }
686   return Status::OK();
687 }
688 
HandlePad(HloInstruction * pad)689 Status ShapeVerifier::HandlePad(HloInstruction* pad) {
690   return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(),
691                                                        pad->operand(1)->shape(),
692                                                        pad->padding_config()));
693 }
694 
HandleSend(HloInstruction * send)695 Status ShapeVerifier::HandleSend(HloInstruction* send) {
696   return CheckShape(send,
697                     ShapeUtil::MakeTupleShape({send->operand(0)->shape(),
698                                                ShapeUtil::MakeShape(U32, {}),
699                                                ShapeUtil::MakeTokenShape()}));
700 }
701 
HandleSendDone(HloInstruction * send_done)702 Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
703   return CheckShape(send_done, ShapeUtil::MakeTokenShape());
704 }
705 
HandleRecv(HloInstruction * recv)706 Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
707   return CheckShape(
708       recv, ShapeUtil::MakeTupleShape(
709                 {ShapeUtil::GetTupleElementShape(recv->shape(), 0),
710                  ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}));
711 }
712 
HandleRecvDone(HloInstruction * recv_done)713 Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
714   return CheckShape(
715       recv_done,
716       ShapeUtil::MakeTupleShape(
717           {ShapeUtil::GetTupleElementShape(recv_done->operand(0)->shape(), 0),
718            ShapeUtil::MakeTokenShape()}));
719 }
720 
HandleBatchNormTraining(HloInstruction * batch_norm_training)721 Status ShapeVerifier::HandleBatchNormTraining(
722     HloInstruction* batch_norm_training) {
723   return CheckShape(batch_norm_training,
724                     ShapeInference::InferBatchNormTrainingShape(
725                         batch_norm_training->operand(0)->shape(),
726                         batch_norm_training->operand(1)->shape(),
727                         batch_norm_training->operand(2)->shape(),
728                         batch_norm_training->feature_index()));
729 }
730 
HandleBatchNormInference(HloInstruction * batch_norm_inference)731 Status ShapeVerifier::HandleBatchNormInference(
732     HloInstruction* batch_norm_inference) {
733   return CheckShape(batch_norm_inference,
734                     ShapeInference::InferBatchNormInferenceShape(
735                         batch_norm_inference->operand(0)->shape(),
736                         batch_norm_inference->operand(1)->shape(),
737                         batch_norm_inference->operand(2)->shape(),
738                         batch_norm_inference->operand(3)->shape(),
739                         batch_norm_inference->operand(4)->shape(),
740                         batch_norm_inference->feature_index()));
741 }
742 
HandleBatchNormGrad(HloInstruction * batch_norm_grad)743 Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
744   return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape(
745                                          batch_norm_grad->operand(0)->shape(),
746                                          batch_norm_grad->operand(1)->shape(),
747                                          batch_norm_grad->operand(2)->shape(),
748                                          batch_norm_grad->operand(3)->shape(),
749                                          batch_norm_grad->operand(4)->shape(),
750                                          batch_norm_grad->feature_index()));
751 }
752 
753 namespace {
754 
755 // Checks that the instruction does not have mixed precision floating point
756 // inputs.
CheckMixedPrecisionOperands(const HloInstruction * instruction)757 Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
758   switch (instruction->opcode()) {
759     // White list the following opcodes for mixed-precision check, because
760     // they involve data pass through or grouping via tuples, where the
761     // precisions of buffers can be different.
762     case HloOpcode::kCall:
763     case HloOpcode::kConditional:
764     case HloOpcode::kConstant:
765     case HloOpcode::kAllReduce:
766     case HloOpcode::kCustomCall:
767     case HloOpcode::kDomain:
768     case HloOpcode::kFusion:
769     case HloOpcode::kGetTupleElement:
770     case HloOpcode::kInfeed:
771     case HloOpcode::kOutfeed:
772     case HloOpcode::kParameter:
773     case HloOpcode::kRecv:
774     case HloOpcode::kRecvDone:
775     case HloOpcode::kReducePrecision:
776     case HloOpcode::kTupleSelect:
777     case HloOpcode::kSend:
778     case HloOpcode::kSendDone:
779     case HloOpcode::kSort:
780     case HloOpcode::kTuple:
781     case HloOpcode::kWhile:
782       break;
783     default: {
784       PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID;
785       for (auto operand : instruction->operands()) {
786         TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
787             operand->shape(),
788             [&](const Shape& subshape, const ShapeIndex& index) {
789               if (!ShapeUtil::ElementIsFloating(subshape)) {
790                 return Status::OK();
791               }
792               if (fp_type == PRIMITIVE_TYPE_INVALID) {
793                 fp_type = subshape.element_type();
794               } else if (fp_type != subshape.element_type()) {
795                 return InternalError(
796                     "Seen floating point types of different precisions in "
797                     "%s, but mixed precision is disallowed.",
798                     instruction->ToString());
799               }
800               return Status::OK();
801             }));
802       }
803     }
804   }
805   return Status::OK();
806 }
807 
808 }  // namespace
809 
HandleGather(HloInstruction * gather)810 Status ShapeVerifier::HandleGather(HloInstruction* gather) {
811   return CheckShape(
812       gather,
813       ShapeInference::InferGatherShape(
814           gather->operand(0)->shape(), gather->operand(1)->shape(),
815           gather->gather_dimension_numbers(), gather->gather_slice_sizes()));
816 }
817 
HandleScatter(HloInstruction * scatter)818 Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
819   return CheckShape(
820       scatter, ShapeInference::InferScatterShape(
821                    scatter->operand(0)->shape(), scatter->operand(1)->shape(),
822                    scatter->operand(2)->shape(),
823                    scatter->to_apply()->ComputeProgramShape(),
824                    scatter->scatter_dimension_numbers()));
825 }
826 
HandleAfterAll(HloInstruction * token)827 Status ShapeVerifier::HandleAfterAll(HloInstruction* token) {
828   std::vector<const Shape*> operand_shapes;
829   for (const HloInstruction* operand : token->operands()) {
830     operand_shapes.push_back(&operand->shape());
831   }
832   return CheckShape(token, ShapeUtil::MakeTokenShape());
833 }
834 
HandleAddDependency(HloInstruction * add_dependency)835 Status ShapeVerifier::HandleAddDependency(HloInstruction* add_dependency) {
836   TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1));
837   return CheckShape(add_dependency, add_dependency->operand(0)->shape());
838 }
839 
HandleGetDimensionSize(HloInstruction * get_size)840 Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
841   return CheckShape(get_size,
842                     ShapeInference::InferGetDimensionSizeShape(
843                         get_size->operand(0)->shape(), get_size->dimension()));
844 }
845 
CheckShape(const HloInstruction * instruction,const Shape & inferred_shape)846 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
847                                  const Shape& inferred_shape) {
848   // If allow_mixed_precision_ is false, check if there are operands with
849   // different precisions. We need this check because ShapeInference allows
850   // mixed precision inputs.
851   if (!allow_mixed_precision_) {
852     TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction));
853   }
854 
855   // Check if the output shape matches the expected shape.
856   //
857   // We treat BF16 and F32 as compatible types if mixed precision is allowed,
858   // but only when the instruction defines the BF16/F32 buffer.
859   bool equal = [&] {
860     switch (instruction->opcode()) {
861       // The opcodes below can't have implicit layout conversions, nor can they
862       // implicitly transform f32 -> bf16.  Fundamentally these are either
863       // reinterpreting existing data (e.g. kBitcast) or shuffling data around
864       // without modifying it (e.g. kGetTupleElement, kTupleSelect).
865       case HloOpcode::kBitcast:
866       case HloOpcode::kCall:
867       case HloOpcode::kConditional:
868       case HloOpcode::kConstant:
869       case HloOpcode::kCustomCall:
870       case HloOpcode::kGetTupleElement:
871       case HloOpcode::kInfeed:
872       case HloOpcode::kOutfeed:
873       case HloOpcode::kParameter:
874       case HloOpcode::kRecv:
875       case HloOpcode::kRecvDone:
876       case HloOpcode::kSend:
877       case HloOpcode::kSendDone:
878       case HloOpcode::kTuple:
879       case HloOpcode::kTupleSelect:
880       case HloOpcode::kWhile:
881         return ShapesSame(instruction->shape(), inferred_shape);
882 
883       // We allow arbitrary layout and f32->bf16 transformations on all other
884       // instructions, although this may be made more strict pending discussion
885       // in b/112709536.
886       default:
887         if (allow_mixed_precision_) {
888           return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(),
889                                                           inferred_shape);
890         } else {
891           return ShapeUtil::Compatible(instruction->shape(), inferred_shape);
892         }
893     }
894   }();
895   if (!equal) {
896     return InternalError(
897         "Expected instruction to have shape equal to %s, actual "
898         "shape is %s:\n%s",
899         StringifyShape(inferred_shape), StringifyShape(instruction->shape()),
900         instruction->ToString());
901   }
902   return Status::OK();
903 }
904 
CheckShape(const HloInstruction * instruction,const StatusOr<Shape> & inferred_shape_status)905 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
906                                  const StatusOr<Shape>& inferred_shape_status) {
907   if (!inferred_shape_status.ok()) {
908     Status s = inferred_shape_status.status();
909     tensorflow::errors::AppendToMessage(&s, ", for instruction ",
910                                         instruction->ToString());
911     return s;
912   }
913   return CheckShape(instruction, inferred_shape_status.ValueOrDie());
914 }
915 
CheckUnaryShape(const HloInstruction * instruction)916 Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
917   return CheckShape(instruction,
918                     ShapeInference::InferUnaryOpShape(instruction->opcode(),
919                                                       instruction->operand(0)));
920 }
921 
CheckBinaryShape(const HloInstruction * instruction)922 Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) {
923   return CheckShape(
924       instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(),
925                                                       instruction->operand(0),
926                                                       instruction->operand(1)));
927 }
928 
CheckTernaryShape(const HloInstruction * instruction)929 Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) {
930   return CheckShape(instruction,
931                     ShapeInference::InferTernaryOpShape(
932                         instruction->opcode(), instruction->operand(0),
933                         instruction->operand(1), instruction->operand(2)));
934 }
935 
CheckVariadicShape(const HloInstruction * instruction)936 Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
937   return CheckShape(instruction,
938                     ShapeInference::InferVariadicOpShape(
939                         instruction->opcode(), instruction->operands()));
940 }
941 
VerifyEntryComputationLayout(const HloModule & module)942 Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) {
943   const HloComputation* computation = module.entry_computation();
944   const auto& layout = module.entry_computation_layout();
945   const ShapeLayout& result_layout = layout.result_layout();
946 
947   TF_RETURN_IF_ERROR(
948       ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape()));
949 
950   TF_RETURN_IF_ERROR(VerifyNotSparse(result_layout.shape()));
951 
952   if (!ShapeUtil::Compatible(computation->root_instruction()->shape(),
953                              result_layout.shape())) {
954     return InternalError(
955         "Shape of the root instruction of entry computation (%s) should be "
956         "compatible to one specified in module's entry computation layout (%s)",
957         ShapeUtil::HumanString(computation->root_instruction()->shape()),
958         ShapeUtil::HumanString(result_layout.shape()));
959   }
960 
961   if (computation->num_parameters() != layout.parameter_count()) {
962     return InternalError(
963         "Number of parameters in entry computation layout (%d) must be same "
964         "as number of parameters of entry computation computation (%d)",
965         layout.parameter_count(), computation->num_parameters());
966   }
967 
968   for (int i = 0; i < computation->num_parameters(); ++i) {
969     const HloInstruction* parameter = computation->parameter_instruction(i);
970     TF_RETURN_IF_ERROR(
971         ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i)));
972     TF_RETURN_IF_ERROR(VerifyNotSparse(layout.parameter_shape(i)));
973     if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) {
974       return InternalError(
975           "Shape of the entry computation parameter %d is %s should be "
976           "compatible to the one specified in module's entry computation "
977           "layout %s",
978           i, ShapeUtil::HumanString(parameter->shape()),
979           ShapeUtil::HumanString(layout.parameter_shape(i)));
980     }
981   }
982 
983   return Status::OK();
984 }
985 
ComputationsToString(absl::Span<HloComputation * const> computations)986 string ComputationsToString(absl::Span<HloComputation* const> computations) {
987   return absl::StrJoin(computations, ",",
988                        [](string* s, const HloComputation* computation) {
989                          s->append(computation->name());
990                        });
991 }
992 
993 // Verifies various invariants about the structure of the HLO:
994 //
995 // (1) each instruction has a non-null parent() set to the HloComputation
996 // which
997 //     contains it.
998 //
999 // (2) each computation has a non-null parent() set to the HloModule which
1000 //     contains it.
1001 //
1002 // (3) the operands of each instruction are in the same computation as the
1003 //     instruction.
VerifyHloStructure(HloModule * module)1004 Status VerifyHloStructure(HloModule* module) {
1005   for (const HloComputation* computation : module->computations()) {
1006     if (computation->parent() == nullptr) {
1007       return InternalError("Computation %s has a null parent pointer",
1008                            computation->name());
1009     }
1010     if (computation->parent() != module) {
1011       return InternalError(
1012           "Computation %s parent() does not point to parent module",
1013           computation->name());
1014     }
1015 
1016     for (const HloInstruction* instruction : computation->instructions()) {
1017       if (instruction->parent() == nullptr) {
1018         return InternalError("Instruction %s has a null parent pointer",
1019                              instruction->name());
1020       }
1021       if (instruction->parent() != computation) {
1022         return InternalError(
1023             "Instruction %s parent() does not point to parent computation",
1024             instruction->name());
1025       }
1026     }
1027   }
1028 
1029   // Check that operands are in the same computation separately from verifying
1030   // parent() correctness so conditions like a null HloInstruction::parent()
1031   // are identified and reported explicitly above rather than reporting a
1032   // mismatched operand.
1033   for (const HloComputation* computation : module->computations()) {
1034     for (const HloInstruction* instruction : computation->instructions()) {
1035       for (int i = 0; i < instruction->operand_count(); ++i) {
1036         const HloInstruction* operand = instruction->operand(i);
1037         if (operand->parent() != instruction->parent()) {
1038           return InternalError(
1039               "Operand %d (%s) of instruction %s is in a different "
1040               "computation: %s vs %s",
1041               i, operand->name(), instruction->name(),
1042               operand->parent()->name(), instruction->parent()->name());
1043         }
1044       }
1045     }
1046   }
1047   return Status::OK();
1048 }
1049 
1050 namespace {
1051 
1052 // Returns true if the given Shape has a TOKEN shape as any subshape.
ShapeContainsToken(const Shape & shape)1053 bool ShapeContainsToken(const Shape& shape) {
1054   bool contains_token = false;
1055   ShapeUtil::ForEachSubshape(
1056       shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
1057         if (subshape.IsToken()) {
1058           contains_token = true;
1059         }
1060       });
1061   return contains_token;
1062 }
1063 
1064 // Verifies that all types entering and exiting the entry computation are
1065 // legal.
VerifyEntryAndExitShapes(const HloModule & module)1066 Status VerifyEntryAndExitShapes(const HloModule& module) {
1067   // Tokens cannot be passed as entry parameters.
1068   // TODO(b/80000000): Remove this constraint.
1069   for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
1070     HloInstruction* param =
1071         module.entry_computation()->parameter_instruction(i);
1072     if (ShapeContainsToken(param->shape())) {
1073       return InternalError(
1074           "Entry parameter %d is or contains a token shape: %s", i,
1075           ShapeUtil::HumanString(param->shape()));
1076     }
1077   }
1078   return Status::OK();
1079 }
1080 
1081 // Checks if the given two instructions share the same channel id.
CheckSameChannel(const HloInstruction * instr1,const HloInstruction * instr2)1082 Status CheckSameChannel(const HloInstruction* instr1,
1083                         const HloInstruction* instr2) {
1084   if (instr1->channel_id() != instr2->channel_id()) {
1085     return InternalError(
1086         "Expected to have the same channel id, actual channel ids are: %s "
1087         "(%d), %s (%d)",
1088         instr1->ToString(), instr1->channel_id(), instr2->ToString(),
1089         instr2->channel_id());
1090   }
1091   return Status::OK();
1092 }
1093 
1094 // Checks if the given two instructions have the same is_host_transfer
1095 // attribute value. Intsructions must be send/recv instructions or their
1096 // 'done' variant.
CheckSameIsHostTransfer(const HloInstruction * instr1,const HloInstruction * instr2)1097 Status CheckSameIsHostTransfer(const HloInstruction* instr1,
1098                                const HloInstruction* instr2) {
1099   const HloSendRecvInstruction* send_recv1 =
1100       DynCast<const HloSendRecvInstruction>(instr1);
1101   const HloSendRecvInstruction* send_recv2 =
1102       DynCast<const HloSendRecvInstruction>(instr2);
1103   TF_RET_CHECK(send_recv1 != nullptr);
1104   TF_RET_CHECK(send_recv2 != nullptr);
1105   if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
1106     return InternalError(
1107         "Expected instructions to have the same is-host-transfer property: "
1108         "%s, "
1109         "%s ",
1110         instr1->ToString(), instr2->ToString());
1111   }
1112   return Status::OK();
1113 }
1114 
1115 // Checks various invariants of send and recv instructions.
VerifySendsAndRecvs(const HloModule & module)1116 Status VerifySendsAndRecvs(const HloModule& module) {
1117   absl::flat_hash_map<int64, const HloInstruction*> host_channels;
1118   // Host send/recv instructions must have their own unique channel.
1119   auto check_unique_host_channel = [&](const HloInstruction* instruction) {
1120     const HloSendRecvInstruction* sendrecv =
1121         DynCast<const HloSendRecvInstruction>(instruction);
1122     if (sendrecv->is_host_transfer()) {
1123       auto it_inserted =
1124           host_channels.insert({sendrecv->channel_id(), sendrecv});
1125       if (!it_inserted.second) {
1126         return FailedPrecondition(
1127             "Channel %d is used for multiple host send/recv instructions: "
1128             "%s "
1129             "and "
1130             "%s",
1131             sendrecv->channel_id(), sendrecv->ToString(),
1132             it_inserted.first->second->ToString());
1133       }
1134     }
1135 
1136     return Status::OK();
1137   };
1138 
1139   // Send/Recv instruction must have a single user: the corresponding
1140   // SendDone/RecvDone. with matching channel.
1141   for (const HloComputation* computation : module.computations()) {
1142     for (const HloInstruction* instruction : computation->instructions()) {
1143       switch (instruction->opcode()) {
1144         case HloOpcode::kSend: {
1145           TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
1146           TF_RET_CHECK(instruction->users().size() == 1);
1147           const HloInstruction* send_done = instruction->users().front();
1148           TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
1149           TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
1150           TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
1151           break;
1152         }
1153         case HloOpcode::kRecv: {
1154           TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
1155           TF_RET_CHECK(instruction->users().size() == 1);
1156           const HloInstruction* recv_done = instruction->users().front();
1157           TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
1158           TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
1159           TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
1160           break;
1161         }
1162         case HloOpcode::kSendDone:
1163           TF_RET_CHECK(instruction->operands().size() == 1);
1164           TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
1165           break;
1166         case HloOpcode::kRecvDone:
1167           TF_RET_CHECK(instruction->operands().size() == 1);
1168           TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
1169           break;
1170         default:
1171           break;
1172       }
1173     }
1174   }
1175   return Status::OK();
1176 }
1177 
1178 // CHECKs various invariants of a fusion instruction.
CheckFusionInstruction(HloInstruction * fusion)1179 Status CheckFusionInstruction(HloInstruction* fusion) {
1180   // The parent fusion instruction of the fusion computation must be 'fusion'.
1181   HloComputation* fused_computation = fusion->fused_instructions_computation();
1182   if (fusion != fused_computation->FusionInstruction()) {
1183     return InternalError(
1184         "Instruction of fused computation does not match expected "
1185         "instruction "
1186         "%s.",
1187         fusion->ToString());
1188   }
1189 
1190   // Fused root instruction and fused parameters must all be owned by the
1191   // fusion computation.
1192   bool root_owned = false;
1193   const std::vector<HloInstruction*>& fused_parameters =
1194       fusion->fused_parameters();
1195   const HloInstruction* fused_root = fusion->fused_expression_root();
1196   std::vector<bool> parameter_owned(fused_parameters.size(), false);
1197   for (auto* instruction : fused_computation->instructions()) {
1198     if (fused_root == instruction) {
1199       if (root_owned) {
1200         return InternalError("Root appears more than once in %s.",
1201                              fusion->ToString());
1202       }
1203       root_owned = true;
1204     }
1205     for (int i = 0; i < fused_parameters.size(); ++i) {
1206       if (fused_parameters[i] == instruction) {
1207         if (parameter_owned[i]) {
1208           return InternalError("Parameter appears more than once in %s.",
1209                                fusion->ToString());
1210         }
1211         parameter_owned[i] = true;
1212       }
1213     }
1214   }
1215   if (!root_owned) {
1216     return InternalError("Root not found in computation of %s.",
1217                          fusion->ToString());
1218   }
1219   // Make sure all the parameter_owned entries are set
1220   for (int i = 0; i < parameter_owned.size(); i++) {
1221     if (!parameter_owned[i]) {
1222       return InternalError("Parameter %d not found in computation of %s.", i,
1223                            fusion->ToString());
1224     }
1225   }
1226 
1227   // Fused root must have no users.
1228   if (fused_root->user_count() != 0) {
1229     return InternalError("Root of %s may not have users.", fusion->ToString());
1230   }
1231 
1232   // All uses of fused instructions must be in the fusion computation, and
1233   // every non-root instruction must have at least one use.
1234   for (auto* instruction :
1235        fusion->fused_instructions_computation()->instructions()) {
1236     if (instruction != fused_root) {
1237       if (instruction->user_count() == 0) {
1238         return InternalError("Non-root instruction %s in %s must have users.",
1239                              instruction->ToString(), fusion->ToString());
1240       }
1241       for (auto& user : instruction->users()) {
1242         if (fused_computation != user->parent()) {
1243           return InternalError(
1244               "Non-root instruction %s in %s may not have external users.",
1245               instruction->ToString(), fusion->ToString());
1246         }
1247       }
1248     }
1249   }
1250 
1251   // Fused parameter instructions must be numbered contiguously and match up
1252   // (shapes equal) with their respective operand.
1253   CHECK_EQ(fusion->operands().size(), fused_parameters.size());
1254   std::vector<bool> parameter_numbers(fused_parameters.size(), false);
1255   for (auto fused_param : fused_parameters) {
1256     int64 param_no = fused_param->parameter_number();
1257     if (param_no < 0) {
1258       return InternalError("Unexpected negative parameter number %d in %s.",
1259                            param_no, fusion->ToString());
1260     }
1261     if (param_no >= fused_parameters.size()) {
1262       return InternalError(
1263           "Unexpected parameter number %d in %s: higher then number of "
1264           "parameters %lu.",
1265           param_no, fusion->ToString(), fused_parameters.size());
1266     }
1267     if (parameter_numbers[param_no]) {
1268       return InternalError(
1269           "Did not expect parameter number %d more than once in %s.", param_no,
1270           fusion->ToString());
1271     }
1272     parameter_numbers[param_no] = true;
1273   }
1274   // Make sure all the parameter_numbers entries were seen.
1275   for (int i = 0; i < parameter_numbers.size(); i++) {
1276     if (!parameter_numbers[i]) {
1277       return InternalError("Did not see parameter number %d in %s.", i,
1278                            fusion->ToString());
1279     }
1280   }
1281 
1282   TF_RET_CHECK(fusion->called_computations() ==
1283                absl::Span<HloComputation* const>(
1284                    {fusion->fused_instructions_computation()}))
1285       << "Fusion HLO calls computations other than the "
1286          "fused_instructions_computation: "
1287       << fusion->ToString() << " fusion->fused_instructions_computation(): "
1288       << fusion->fused_instructions_computation()->ToString()
1289       << " fusion->called_computations(): "
1290       << ComputationsToString(fusion->called_computations());
1291 
1292   for (const auto& fused : fusion->fused_instructions()) {
1293     TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation())
1294         << "Fused HLO was missing a parent: " << fused->ToString()
1295         << " parent: " << fused->parent()
1296         << " computation: " << fusion->parent();
1297   }
1298 
1299   // TODO(b/65423525): We'd like to check that all operands are distinct.
1300   // This is currently disabled due to the invariant being violated by
1301   // multi-output fusion.
1302   return Status::OK();
1303 }
1304 
1305 // Checks that the operand shapes are compatible to the output shape, i.e.,
1306 // that there are no implicit broadcasts.
CheckElementwiseInstruction(HloInstruction * instruction)1307 Status CheckElementwiseInstruction(HloInstruction* instruction) {
1308   const Shape& out_shape = instruction->shape();
1309   for (HloInstruction* operand : instruction->operands()) {
1310     const Shape& operand_shape = operand->shape();
1311     if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
1312       return FailedPrecondition(
1313           "Implicit broadcast is not allowed in HLO."
1314           "Found different shapes for instruction %s.\n"
1315           "output: %s\noperand: %s\n",
1316           HloOpcodeString(instruction->opcode()),
1317           ShapeUtil::HumanString(out_shape),
1318           ShapeUtil::HumanString(operand_shape));
1319     }
1320   }
1321   return Status::OK();
1322 }
1323 
1324 // Visitor which verifies various fields on the HLO instruction. This class does
1325 // not check result shape as that is checked in the ShapeVerifier.
1326 class InstructionVerifier : public DfsHloVisitorWithDefault {
1327  public:
InstructionVerifier(std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)1328   explicit InstructionVerifier(std::function<bool(const HloInstruction*)>
1329                                    instruction_can_change_layout_func)
1330       : instruction_can_change_layout_func_(
1331             instruction_can_change_layout_func) {}
1332 
DefaultAction(HloInstruction *)1333   Status DefaultAction(HloInstruction*) override { return Status::OK(); }
1334 
HandleFusion(HloInstruction * fusion)1335   Status HandleFusion(HloInstruction* fusion) override {
1336     return CheckFusionInstruction(fusion);
1337   }
1338 
HandleBroadcast(HloInstruction * broadcast)1339   Status HandleBroadcast(HloInstruction* broadcast) override {
1340     // If you see this failure then someone has confused the difference
1341     // between the HLO broadcast op, and the UserComputation broadcast
1342     // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
1343     // or ComputationLowerer::Visit()
1344     TF_RET_CHECK(broadcast->dimensions().size() ==
1345                  broadcast->operand(0)->shape().rank())
1346         << "Broadcast HLO (" << broadcast->ToShortString()
1347         << ") has invalid number of dimensions: "
1348         << broadcast->dimensions().size()
1349         << " != " << broadcast->operand(0)->shape().rank();
1350     return Status::OK();
1351   }
1352 
HandleWhile(HloInstruction * xla_while)1353   Status HandleWhile(HloInstruction* xla_while) override {
1354     auto* while_cond = xla_while->while_condition();
1355     auto* while_body = xla_while->while_body();
1356     if (while_cond->num_parameters() != 1) {
1357       return FailedPrecondition(
1358           "While condition must have exactly 1 parameter; had %d : %s",
1359           while_cond->num_parameters(), while_cond->ToString());
1360     }
1361     if (while_body->num_parameters() != 1) {
1362       return FailedPrecondition(
1363           "While body must have exactly 1 parameter; had %d : %s",
1364           while_body->num_parameters(), while_body->ToString());
1365     }
1366     if (xla_while->operand_count() != 1) {
1367       return FailedPrecondition(
1368           "While loop must have exactly one operand; had %d : %s",
1369           xla_while->operand_count(), xla_while->ToString());
1370     }
1371     return Status::OK();
1372   }
1373 
HandleConditional(HloInstruction * conditional)1374   Status HandleConditional(HloInstruction* conditional) override {
1375     for (int b = 0; b < conditional->branch_count(); ++b) {
1376       if (conditional->branch_computation(b)->num_parameters() != 1) {
1377         return FailedPrecondition(
1378             "Branch computation %s of %s must have 1 parameter insted of %d",
1379             conditional->branch_computation(b)->name(), conditional->ToString(),
1380             conditional->branch_computation(b)->num_parameters());
1381       }
1382     }
1383     return Status::OK();
1384   }
1385 
HandleElementwiseUnary(HloInstruction * instruction)1386   Status HandleElementwiseUnary(HloInstruction* instruction) override {
1387     return CheckElementwiseInstruction(instruction);
1388   }
1389 
HandleElementwiseBinary(HloInstruction * instruction)1390   Status HandleElementwiseBinary(HloInstruction* instruction) override {
1391     return CheckElementwiseInstruction(instruction);
1392   }
1393 
HandleGetTupleElement(HloInstruction * gte)1394   Status HandleGetTupleElement(HloInstruction* gte) override {
1395     TF_RET_CHECK(gte->operand(0)->shape().IsTuple());
1396     return Status::OK();
1397   }
1398 
HandleTranspose(HloInstruction * transpose)1399   Status HandleTranspose(HloInstruction* transpose) override {
1400     const Shape& shape = transpose->shape();
1401     const HloInstruction* operand = transpose->operand(0);
1402     TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size());
1403     TF_RET_CHECK(shape.dimensions().size() ==
1404                  transpose->operand(0)->shape().dimensions().size());
1405     TF_RET_CHECK(std::equal(
1406         operand->shape().dimensions().begin(),
1407         operand->shape().dimensions().end(),
1408         Permute(transpose->dimensions(), shape.dimensions()).begin()))
1409         << "shape: " << shape << ", operand->shape(): " << shape
1410         << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ")
1411         << "}";
1412     return Status::OK();
1413   }
1414 
HandleAllReduce(HloInstruction * crs)1415   Status HandleAllReduce(HloInstruction* crs) override {
1416     if (crs->all_reduce_id().has_value()) {
1417       TF_RET_CHECK(crs->all_reduce_id().value() > 0)
1418           << "All reduce id must be greater than 0 for "
1419           << crs->ToShortString();
1420     }
1421     return Status::OK();
1422   }
1423 
Preprocess(HloInstruction * instruction)1424   Status Preprocess(HloInstruction* instruction) override {
1425     auto previous = instructions_by_name_.find(instruction->name());
1426     TF_RET_CHECK(previous == instructions_by_name_.end())
1427         << "HLO has name that is not unique within module:\n"
1428         << instruction->ToString()
1429         << " in computation: " << instruction->parent()->name()
1430         << "\nPrevious HLO with same name:\n"
1431         << previous->second->ToString()
1432         << " in computation: " << previous->second->parent()->name();
1433     instructions_by_name_[instruction->name()] = instruction;
1434     return Status::OK();
1435   }
1436 
Postprocess(HloInstruction * instruction)1437   Status Postprocess(HloInstruction* instruction) override {
1438     if (instruction_can_change_layout_func_ &&
1439         LayoutUtil::IsDenseArray(instruction->shape()) &&
1440         !instruction_can_change_layout_func_(instruction)) {
1441       const Shape& result_shape = instruction->shape();
1442       const Layout& result_layout = result_shape.layout();
1443       for (HloInstruction* operand : instruction->operands()) {
1444         const Shape& operand_shape = operand->shape();
1445         if (LayoutUtil::IsDenseArray(operand_shape) &&
1446             operand_shape.rank() == result_shape.rank()) {
1447           const Layout& operand_layout = operand_shape.layout();
1448           TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
1449               << "Instruction shouldn't change layouts "
1450               << instruction->ToString() << " From " << result_shape << " To "
1451               << operand_shape;
1452         }
1453       }
1454     }
1455 
1456     return Status::OK();
1457   }
1458 
1459  private:
1460   absl::flat_hash_map<string, const HloInstruction*> instructions_by_name_;
1461   // Determines whether an instruction can change layouts.
1462   std::function<bool(const HloInstruction*)>
1463       instruction_can_change_layout_func_;
1464 };
1465 
1466 }  // namespace
1467 
Run(HloModule * module)1468 StatusOr<bool> HloVerifier::Run(HloModule* module) {
1469   TF_RET_CHECK(!module->name().empty());
1470 
1471   if (module->entry_computation()->IsFusionComputation()) {
1472     return InvalidArgument(
1473         "Module entry computation cannot be a fusion computation");
1474   }
1475 
1476   TF_RETURN_IF_ERROR(VerifyHloStructure(module));
1477   TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
1478 
1479   std::unique_ptr<ShapeVerifier> shape_verifier =
1480       target_metadata_->GetVerifier();
1481   for (auto* computation : module->computations()) {
1482     TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
1483 
1484     InstructionVerifier instruction_verifier(
1485         instruction_can_change_layout_func_);
1486     TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
1487   }
1488 
1489   TF_RETURN_IF_ERROR(shape_verifier->VerifyEntryComputationLayout(*module));
1490   TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
1491 
1492   // If the module has a schedule, it must be valid.
1493   if (module->has_schedule()) {
1494     TF_RETURN_IF_ERROR(module->schedule().Verify());
1495   }
1496 
1497   TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(
1498       *module, [this](const Shape& shape) {
1499         return target_metadata_->ShapeSize(shape);
1500       }));
1501 
1502   TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module));
1503 
1504   return false;
1505 }
1506 
1507 }  // namespace xla
1508