1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/hlo_parser.h"
26 #include "tensorflow/compiler/xla/service/layout_assignment.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/test.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/compiler/xla/xla.pb.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 
35 namespace xla {
36 namespace {
37 
38 using ::testing::HasSubstr;
39 
CreateUnverifiedModule()40 std::unique_ptr<HloModule> CreateUnverifiedModule() {
41   return absl::make_unique<HloModule>("module", HloModuleConfig());
42 }
43 
44 // This class cannot be converted to use HloTestBase. It explicitly
45 // uses HloTestBase to create and test malformed HLOs.
46 class HloVerifierTest : public HloTestBase {
47  public:
HloVerifierTest()48   HloVerifierTest()
49       : HloTestBase(/*verifier_layout_sensitive=*/false,
50                     /*allow_mixed_precision_in_hlo_verifier=*/false) {}
51 };
52 
53 class HloVerifierTestAllowMixedPrecision : public HloTestBase {
54  public:
HloVerifierTestAllowMixedPrecision()55   HloVerifierTestAllowMixedPrecision()
56       : HloTestBase(/*verifier_layout_sensitive=*/false,
57                     /*allow_mixed_precision_in_hlo_verifier=*/true) {}
58 };
59 
60 class HloVerifierTestLayoutSensitive : public HloTestBase {
61  public:
HloVerifierTestLayoutSensitive()62   HloVerifierTestLayoutSensitive()
63       : HloTestBase(/*verifier_layout_sensitive=*/true,
64                     /*allow_mixed_precision_in_hlo_verifier=*/false,
65                     LayoutAssignment::InstructionCanChangeLayout) {}
66 };
67 
// Clearing an instruction's parent pointer must make an otherwise valid module
// fail verification.
TEST_F(HloVerifierTest, NullInstructionParent) {
  auto module = CreateUnverifiedModule();
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* negation = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, input));
  module->AddEntryComputation(entry_builder.Build());

  // Sanity check: the module verifies before it is corrupted.
  TF_ASSERT_OK(verifier().Run(module.get()).status());

  negation->set_parent(nullptr);

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("has a null parent pointer"));
}
86 
// Clearing a computation's parent (module) pointer must make an otherwise valid
// module fail verification.
TEST_F(HloVerifierTest, NullComputationParent) {
  auto module = CreateUnverifiedModule();
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  entry_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, input));
  HloComputation* entry = module->AddEntryComputation(entry_builder.Build());

  // Sanity check: the module verifies before it is corrupted.
  TF_ASSERT_OK(verifier().Run(module.get()).status());

  entry->set_parent(nullptr);

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("has a null parent pointer"));
}
105 
// Rewiring an instruction to use an operand that lives in another computation
// must be reported by the verifier.
TEST_F(HloVerifierTest, DifferentOperandParents) {
  auto module = CreateUnverifiedModule();
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  // Entry computation: negate(param).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* negation = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, entry_param));
  module->AddEntryComputation(entry_builder.Build());

  // A separate embedded computation holding only a parameter.
  HloComputation::Builder embedded_builder(TestName());
  HloInstruction* embedded_param = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  module->AddEmbeddedComputation(embedded_builder.Build());

  TF_ASSERT_OK(verifier().Run(module.get()).status());

  // Point the entry computation's negate at the embedded computation's param.
  TF_ASSERT_OK(negation->ReplaceOperandWith(0, embedded_param));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("is in a different computation"));
}
129 
// The verifier must not carry state between runs: a module with a badly shaped
// instruction has to be rejected on every invocation, not just the first.
TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
  const Shape f32_1 = ShapeUtil::MakeShape(F32, {1});
  const Shape f32_2 = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_1, "param"));

  // add(f32[1], f32[1]) declared as f32[2]: an intentionally wrong shape.
  HloInstruction* bad_add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_2, HloOpcode::kAdd, input, input));

  // To trigger the bug this test guards against, the badly shaped instruction
  // must not be the root, so a multiply consumes it.
  builder.AddInstruction(HloInstruction::CreateBinary(
      f32_2, HloOpcode::kMultiply, bad_add, bad_add));

  auto module = CreateUnverifiedModule();
  module->AddEntryComputation(builder.Build());

  // Both runs must fail: state left in the DFS visitor by the first run must
  // not mask the error in the second.
  for (int run = 0; run < 2; ++run) {
    EXPECT_FALSE(verifier().Run(module.get()).status().ok());
  }
}
155 
// A call whose operand shape differs from the callee's parameter shape must be
// rejected. The entry passes (f32[4], s32[]) to a callee that takes
// (s32[], f32[4]).
TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
  constexpr char kHlo[] = R"(
HloModule Module

callme {
  ROOT param = (s32[], f32[4]) parameter(0)
}

ENTRY entry {
  p0 = (f32[4], s32[]) parameter(0)
  ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shape does not match parameter"));
}
176 
// A conditional whose branch operands differ in shape from the branch
// computations' parameters must be rejected. Both branches take
// (s32[], f32[4]) but receive (f32[4], s32[]).
TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
  constexpr char kHlo[] = R"(
HloModule Module

true_branch {
  tparam = (s32[], f32[4]) parameter(0)
  ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
}

false_branch {
  fparam = (s32[], f32[4]) parameter(0)
  ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
}

ENTRY entry {
  p0 = (f32[4], s32[]) parameter(0)
  constant = pred[] constant(true)
  ROOT conditional = f32[4] conditional(constant, p0, p0),
    true_computation=true_branch, false_computation=false_branch
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shape does not match parameter"));
}
205 
// An rng with a non-scalar bound must be rejected. Note: despite the test name,
// the non-scalar operand here is the second one (constant.1 is f16[2]).
TEST_F(HloVerifierTest, RngOpnd0NotScalar) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY RngOpnd0NotScalar {
   constant.0 = f32[] constant(0)
   constant.1 = f16[2] constant({1, 3})
   ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1),
    distribution=rng_uniform
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected scalar type"));
}
223 
// rng bounds of different element types (f32 and f16) must be rejected.
TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY RngOperandElementTypesNotMatch {
   constant.0 = f32[] constant(0)
   constant.1 = f16[] constant(1)
   ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected compatible element types"));
}
242 
// With mixed precision disallowed, an rng producing f16 from f32 bounds must be
// rejected.
TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY RngResultElementTypeNotMatch {
   constant.0 = f32[] constant(0)
   constant.1 = f32[] constant(1)
   ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected compatible element types"));
}
261 
// With mixed precision allowed, an rng producing f16 from f32 bounds verifies.
TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) {
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY RngResultElementTypeNotMatch {
   constant.0 = f32[] constant(0)
   constant.1 = f32[] constant(1)
   ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));

  // TF_ASSERT_OK prints the verifier's error on failure; ASSERT_TRUE(ok())
  // would only report that ok() returned false.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
278 
// An rng_normal over s32 values must be rejected as an unsupported element
// type.
TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY RngElementTypeNotSupported {
   constant.0 = s32[] constant(0)
   constant.1 = s32[] constant(1)
   ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Element type not supported"));
}
296 
// A pad with negative interior padding must be rejected.
TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
  // Built programmatically: the HLO text parser does not accept negative
  // interior padding (arguably a feature).
  const Shape f32_100 = ShapeUtil::MakeShape(F32, {100});

  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_100, "param"));
  HloInstruction* pad_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));

  PaddingConfig padding_config;
  padding_config.add_dimensions()->set_interior_padding(-1);
  builder.AddInstruction(
      HloInstruction::CreatePad(f32_100, operand, pad_value, padding_config));

  auto module = CreateUnverifiedModule();
  module->AddEntryComputation(builder.Build());

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
320 
// A pad with a negative interior padding (interior dilation) value must be
// rejected.
TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
  // This testcase can't be written using textual HLO, because the parser does
  // not accept negative interior padding.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {100}), "param"));
  PaddingConfig padding_config;
  padding_config.add_dimensions()->set_interior_padding(-1);
  // LiteralUtil::Zero already yields a fresh Literal, so no Clone() is needed.
  builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {100}), param,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
      padding_config));

  auto module = CreateUnverifiedModule();
  module->AddEntryComputation(builder.Build());

  // Assert the failure explicitly, as the sibling tests do, so a verifier that
  // unexpectedly succeeds is reported directly rather than as a mismatch
  // against an empty error message.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
342 
// Simple module containing a convolution as the root. The Conv*Dilation tests
// parse it and then overwrite the convolution's window with invalid dilation
// factors. Note: zero_f16 is not used by any other instruction.
static const char* const kConvHloString = R"(
HloModule module
ENTRY entry_computation {
  param0 = f16[128,128,56,56] parameter(0)
  param1 = f16[3,3,128,128] parameter(1)
  zero_f16 = f16[] constant(0)
  ROOT conv = f16[128,128,28,28] convolution(param0, param1),
    window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01
})";
353 
// A convolution whose window has a negative window dilation must be rejected.
TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString));
  auto* conv = module->entry_computation()->root_instruction();
  Window w = conv->window();
  w.mutable_dimensions(0)->set_window_dilation(-1);
  conv->set_window(w);

  // Assert the failure first, consistent with the other negative tests, so an
  // unexpectedly passing verifier is reported directly.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("non-positive window dilation factor"));
}
364 
// A convolution whose window has a negative base dilation must be rejected.
TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString));
  auto* conv = module->entry_computation()->root_instruction();
  Window w = conv->window();
  w.mutable_dimensions(0)->set_base_dilation(-1);
  conv->set_window(w);

  // Assert the failure first, consistent with the other negative tests, so an
  // unexpectedly passing verifier is reported directly.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("non-positive base area dilation factor"));
}
375 
// An add whose operands have different layouts ({1,0} and {0,1}) and whose
// result has layout {1,0}. AddWithLayoutChange expects the layout-insensitive
// verifier to accept it; AddWithLayoutChangeNotAllowed expects the
// layout-sensitive verifier to reject it.
static const char* const kAddWithLayoutChangeHlo = R"(
   HloModule AddWithLayoutChange
    ENTRY AddWithLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[3,4]{0,1} parameter(1)
      ROOT add0 = f32[3,4]{1,0} add(par0,par1)
    }
  )";
384 
// The layout-insensitive verifier accepts an add whose operand layouts differ.
TEST_F(HloVerifierTest, AddWithLayoutChange) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo));
  // TF_ASSERT_OK prints the verifier's error message if verification fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
390 
// With xla_allow_scalar_index_dynamic_ops enabled, a dynamic-slice that takes
// one scalar start index per dimension verifies.
TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) {
  const char* const kScalarIndexDynamicSlice = R"(
    HloModule DynamicSlice_module

    ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
      %original_parameter = s32[2,2,258] parameter(0)
      %constant = s32[] constant(0)
      %start_index = s32[] parameter(1)
      ROOT %dynamic-slice = s32[2,2,258] dynamic-slice(s32[2,2,258] %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
    }
  )";

  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseHloString(kScalarIndexDynamicSlice, config));
  // TF_ASSERT_OK prints the verifier's error message if verification fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
413 
// With xla_allow_scalar_index_dynamic_ops enabled, a dynamic-update-slice that
// takes one scalar start index per dimension verifies.
TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) {
  // Named for the op under test; this is a dynamic-update-slice module.
  const char* const kScalarIndexDynamicUpdateSlice = R"(
    HloModule DynamicUpdateSlice_module

    ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
      %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
      %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
      %start_index.0 = s32[] parameter(2)
      %start_index.1 = s32[] parameter(3)
      %start_index.2 = s32[] parameter(4)
      %start_index.3 = s32[] parameter(5)
      ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3)
    }
  )";

  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseHloString(kScalarIndexDynamicUpdateSlice, config));
  // TF_ASSERT_OK prints the verifier's error message if verification fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
439 
// The layout-sensitive verifier must reject an add whose operand layouts
// differ ({1,0} vs {0,1}).
TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
447 
// The layout-sensitive verifier must reject a dynamic-slice whose result layout
// ({1,0}) differs from its operand layout ({0,1}).
TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
  constexpr char kSliceWithLayoutChangeHlo[] = R"(
   HloModule SliceWithLayoutChange
    ENTRY SliceWithLayoutChange {
      par0 = f32[4,5]{0,1} parameter(0)
      par1 = s32[] parameter(1)
      par2 = s32[] parameter(2)
      ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2),
        dynamic_slice_sizes={3,4}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseHloString(kSliceWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
466 
// The layout-sensitive verifier must reject a concatenate whose operands and
// result do not share a layout (par0 is {0,1}, the rest {1,0}).
TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
  constexpr char kConcatWithLayoutChangeHlo[] = R"(
   HloModule ConcatWithLayoutChange
   ENTRY ConcatWithLayoutChange {
      par0 = f32[3,5]{0,1} parameter(0)
      par1 = f32[3,3]{1,0} parameter(1)
      ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
        dimensions={1}
   }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseHloString(kConcatWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
484 
// A bitcast from f32[2] to s32[2] changes the element type and must be
// rejected.
TEST_F(HloVerifierTest, BitcastCanNotChangeElementType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY BitcastCanNotChangeElementType {
   constant.0 = f32[2] constant({0.0, 0.0})
   ROOT bitcast = s32[2] bitcast(constant.0)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Bitcast can not change the element type"));
}
501 
// With mixed precision disallowed, a select between f32 and bf16 values must be
// rejected.
TEST_F(HloVerifierTest, SelectMixedPrecisionNotAllowed) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY SelectMixedPrecisionNotAllowed {
   p0 = pred[] parameter(0)
   p1 = f32[32] parameter(1)
   p2 = bf16[32] parameter(2)
   ROOT select = f32[32] select(p0, p1, p2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Seen floating point types of different precisions"));
}
520 
// With mixed precision allowed, a select between f32 and bf16 values verifies.
TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) {
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY SelectMixedPrecisionAllowed {
   p0 = pred[] parameter(0)
   p1 = f32[32] parameter(1)
   p2 = bf16[32] parameter(2)
   ROOT select = f32[32] select(p0, p1, p2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));

  // TF_ASSERT_OK prints the verifier's error message if verification fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
537 
// An iota whose result is a tuple (here the empty tuple) must be rejected.
TEST_F(HloVerifierTest, IotaNonArrayResult) {
  constexpr char kHlo[] = R"(
  HloModule IotaTupleResult

  ENTRY  kernelEntry {
    ROOT iota = () iota(), iota_dimension=24
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("does not support non-array result"));
}
554 
// A map whose to_apply computation takes an f32[] parameter while the mapped
// operand is f64[]. The MapOperandComputationMismatch tests expect it to be
// rejected when mixed precision is disallowed and accepted when it is allowed.
static const char* const kMapOperandComputationMismatchHlo = R"(
  HloModule MapOperandComputationMismatch

  Computation {
    param0 = f32[] parameter(0)
    constant = f32[] constant(1)
    ROOT add = f32[] add(param0, constant)
  }

  ENTRY kernelEntry {
  param = f64[] parameter(0)
  ROOT map = f32[] map(param), dimensions={}, to_apply=Computation
})";
568 
// With mixed precision disallowed, mapping an f32 computation over an f64
// operand must be rejected.
TEST_F(HloVerifierTest, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseHloString(kMapOperandComputationMismatchHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr(
          "Shape mismatch between to_apply computation parameter and operand"));
}
579 
// With mixed precision allowed, mapping an f32 computation over an f64 operand
// verifies.
TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseHloString(kMapOperandComputationMismatchHlo));
  // TF_ASSERT_OK prints the verifier's error message if verification fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
586 
// A reduce over an f16 operand and f16 init value whose reducer computation
// operates on f32. The ReduceOperandComputationMismatch tests expect it to be
// rejected when mixed precision is disallowed and accepted when it is allowed.
// NOTE(review): the reduce is not marked ROOT; as the last instruction of the
// entry it is presumably taken as the root by the parser — confirm.
static const char* const kReduceOperandComputationMismatchHlo = R"(
  HloModule ReduceOperandComputationMismatch
  computation {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }

  ENTRY kernelEntry {
    arg0 = f16[64,64,224,224]{3,2,1,0} parameter(0)
    constant = f16[] constant(0)
    reduce = f16[64]{0} reduce(arg0, constant), dimensions={0,2,3}, to_apply=computation
  })";
600 
// With mixed precision disallowed, reducing f16 data with an f32 reducer must
// be rejected; the verifier expects the f32[64] result the reducer implies.
TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseHloString(kReduceOperandComputationMismatchHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to f32[64]"));
}
609 
// With mixed precision allowed, reducing f16 data with an f32 reducer verifies.
TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseHloString(kReduceOperandComputationMismatchHlo));
  // TF_ASSERT_OK prints the verifier's error message if verification fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
616 
617 }  // namespace
618 }  // namespace xla
619