1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/hlo_parser.h"
26 #include "tensorflow/compiler/xla/service/layout_assignment.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/test.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/compiler/xla/xla.pb.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34
35 namespace xla {
36 namespace {
37
38 using ::testing::HasSubstr;
39
CreateUnverifiedModule()40 std::unique_ptr<HloModule> CreateUnverifiedModule() {
41 return absl::make_unique<HloModule>("module", HloModuleConfig());
42 }
43
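// This fixture cannot be converted to a test base that automatically verifies
// every module (e.g. HloVerifiedTestBase): it deliberately uses plain
// HloTestBase so that it can create malformed HLOs and test the verifier on them.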
46 class HloVerifierTest : public HloTestBase {
47 public:
HloVerifierTest()48 HloVerifierTest()
49 : HloTestBase(/*verifier_layout_sensitive=*/false,
50 /*allow_mixed_precision_in_hlo_verifier=*/false) {}
51 };
52
53 class HloVerifierTestAllowMixedPrecision : public HloTestBase {
54 public:
HloVerifierTestAllowMixedPrecision()55 HloVerifierTestAllowMixedPrecision()
56 : HloTestBase(/*verifier_layout_sensitive=*/false,
57 /*allow_mixed_precision_in_hlo_verifier=*/true) {}
58 };
59
60 class HloVerifierTestLayoutSensitive : public HloTestBase {
61 public:
HloVerifierTestLayoutSensitive()62 HloVerifierTestLayoutSensitive()
63 : HloTestBase(/*verifier_layout_sensitive=*/true,
64 /*allow_mixed_precision_in_hlo_verifier=*/false,
65 LayoutAssignment::InstructionCanChangeLayout) {}
66 };
67
// A module that verifies cleanly must start failing once one of its
// instructions has its parent pointer cleared.
TEST_F(HloVerifierTest, NullInstructionParent) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* neg = b.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, p0));
  std::unique_ptr<HloModule> m = CreateUnverifiedModule();
  m->AddEntryComputation(b.Build());

  // Sanity check: the well-formed module passes.
  TF_ASSERT_OK(verifier().Run(m.get()).status());

  neg->set_parent(nullptr);

  const auto after = verifier().Run(m.get()).status();
  ASSERT_FALSE(after.ok());
  EXPECT_THAT(after.error_message(), HasSubstr("has a null parent pointer"));
}
86
// A module that verifies cleanly must start failing once its entry
// computation has its parent (module) pointer cleared.
TEST_F(HloVerifierTest, NullComputationParent) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  b.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, p0));
  std::unique_ptr<HloModule> m = CreateUnverifiedModule();
  HloComputation* entry = m->AddEntryComputation(b.Build());

  // Sanity check: the well-formed module passes.
  TF_ASSERT_OK(verifier().Run(m.get()).status());

  entry->set_parent(nullptr);

  const auto after = verifier().Run(m.get()).status();
  ASSERT_FALSE(after.ok());
  EXPECT_THAT(after.error_message(), HasSubstr("has a null parent pointer"));
}
105
// Rewiring an entry-computation instruction so that it consumes a parameter
// owned by a separate embedded computation must be flagged by the verifier.
TEST_F(HloVerifierTest, DifferentOperandParents) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* neg = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate,
                                  entry_param));
  std::unique_ptr<HloModule> m = CreateUnverifiedModule();
  m->AddEntryComputation(entry_builder.Build());

  HloComputation::Builder other_builder(TestName());
  HloInstruction* foreign_param = other_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  m->AddEmbeddedComputation(other_builder.Build());

  // Both computations are individually well formed at this point.
  TF_ASSERT_OK(verifier().Run(m.get()).status());
  TF_ASSERT_OK(neg->ReplaceOperandWith(0, foreign_param));

  const auto after = verifier().Run(m.get()).status();
  ASSERT_FALSE(after.ok());
  EXPECT_THAT(after.error_message(),
              HasSubstr("is in a different computation"));
}
129
// Regression test: the verifier must not carry state in its DFS visitor across
// Run() calls, so a module with a mis-shaped, non-root add fails on every run.
TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
  const Shape one_elem = ShapeUtil::MakeShape(F32, {1});
  const Shape two_elem = ShapeUtil::MakeShape(F32, {2});
  HloComputation::Builder b(TestName());

  HloInstruction* p =
      b.AddInstruction(HloInstruction::CreateParameter(0, one_elem, "param"));

  // f32[1] + f32[1] declared as f32[2]: the shape is deliberately wrong.
  HloInstruction* bad_add = b.AddInstruction(
      HloInstruction::CreateBinary(two_elem, HloOpcode::kAdd, p, p));

  // The bug being guarded against only shows up when the badly-shaped
  // instruction is not the computation root, so consume it with a multiply.
  b.AddInstruction(HloInstruction::CreateBinary(two_elem, HloOpcode::kMultiply,
                                                bad_add, bad_add));

  std::unique_ptr<HloModule> m = CreateUnverifiedModule();
  m->AddEntryComputation(b.Build());

  for (int attempt = 0; attempt < 2; ++attempt) {
    EXPECT_FALSE(verifier().Run(m.get()).status().ok());
  }
}
155
// The call operand is (f32[4], s32[]) but the callee's parameter is
// (s32[], f32[4]); the verifier must reject the mismatch.
TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
  constexpr char kHlo[] = R"(
HloModule Module

callme {
ROOT param = (s32[], f32[4]) parameter(0)
}

ENTRY entry {
p0 = (f32[4], s32[]) parameter(0)
ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("shape does not match parameter"));
}
176
// Both conditional branches expect (s32[], f32[4]) but are fed an
// (f32[4], s32[]) operand; the verifier must reject the mismatch.
TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
  constexpr char kHlo[] = R"(
HloModule Module

true_branch {
tparam = (s32[], f32[4]) parameter(0)
ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
}

false_branch {
fparam = (s32[], f32[4]) parameter(0)
ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
}

ENTRY entry {
p0 = (f32[4], s32[]) parameter(0)
constant = pred[] constant(true)
ROOT conditional = f32[4] conditional(constant, p0, p0),
true_computation=true_branch, false_computation=false_branch
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("shape does not match parameter"));
}
205
// An rng whose second operand is an f16[2] array (not a scalar) is invalid.
TEST_F(HloVerifierTest, RngOpnd0NotScalar) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY RngOpnd0NotScalar {
constant.0 = f32[] constant(0)
constant.1 = f16[2] constant({1, 3})
ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1),
distribution=rng_uniform
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("Expected scalar type"));
}
223
// An rng with an f32 operand and an f16 operand is invalid.
TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY RngOperandElementTypesNotMatch {
constant.0 = f32[] constant(0)
constant.1 = f16[] constant(1)
ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected compatible element types"));
}
242
// With mixed precision disallowed, f32 operands producing an f16 rng result
// must be rejected.
TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY RngResultElementTypeNotMatch {
constant.0 = f32[] constant(0)
constant.1 = f32[] constant(1)
ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected compatible element types"));
}
261
TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY RngResultElementTypeNotMatch {
constant.0 = f32[] constant(0)
constant.1 = f32[] constant(1)
ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));

  // With mixed precision allowed, f32 operands feeding an f16 rng result must
  // verify. TF_ASSERT_OK prints the verifier's error message if they don't,
  // whereas ASSERT_TRUE(status.ok()) would only report "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
278
// An rng_normal over s32 is rejected because the element type is unsupported.
TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY RngElementTypeNotSupported {
constant.0 = s32[] constant(0)
constant.1 = s32[] constant(1)
ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("Element type not supported"));
}
296
// A pad with interior padding of -1 must be rejected. The HLO is built
// programmatically because the text parser does not accept negative interior
// padding (which is arguably a feature).
TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
  const Shape vec100 = ShapeUtil::MakeShape(F32, {100});
  HloComputation::Builder b(TestName());
  HloInstruction* operand =
      b.AddInstruction(HloInstruction::CreateParameter(0, vec100, "param"));

  PaddingConfig negative_interior;
  negative_interior.add_dimensions()->set_interior_padding(-1);

  HloInstruction* pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  b.AddInstruction(HloInstruction::CreatePad(vec100, operand, pad_value,
                                             negative_interior));

  std::unique_ptr<HloModule> m = CreateUnverifiedModule();
  m->AddEntryComputation(b.Build());

  const auto result = verifier().Run(m.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
320
TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
  // This testcase can't be written using textual HLO, because it doesn't parse
  // negative interior padding. That's probably a feature. :)
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {100}), "param"));
  PaddingConfig padding_config;
  padding_config.add_dimensions()->set_interior_padding(-1);
  // LiteralUtil::Zero(F32) can be handed straight to CreateConstant; the
  // previous extra Clone() only made a redundant copy of the literal.
  builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {100}), param,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
      padding_config));

  auto module = CreateUnverifiedModule();
  module->AddEntryComputation(builder.Build());

  // Assert failure first so an unexpectedly passing verifier is reported as
  // such, rather than as an empty message missing the expected substring.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
342
// HLO text for a module whose entry root is an f16 convolution of param0 with
// param1 (3x3 window, stride 2). Tests below mutate that root's window.
constexpr char kConvHloString[] = R"(
HloModule module
ENTRY entry_computation {
param0 = f16[128,128,56,56] parameter(0)
param1 = f16[3,3,128,128] parameter(1)
zero_f16 = f16[] constant(0)
ROOT conv = f16[128,128,28,28] convolution(param0, param1),
window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01
})";
353
// Setting a window dilation of -1 on the convolution's first spatial window
// dimension must make verification fail.
TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString));
  auto* conv = module->entry_computation()->root_instruction();
  Window w = conv->window();
  w.mutable_dimensions(0)->set_window_dilation(-1);
  conv->set_window(w);

  // Assert failure first, consistent with the other negative tests in this
  // file, so an unexpectedly passing verifier yields a clear failure.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("non-positive window dilation factor"));
}
364
// Setting a base dilation of -1 on the convolution's first spatial window
// dimension must make verification fail.
TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString));
  auto* conv = module->entry_computation()->root_instruction();
  Window w = conv->window();
  w.mutable_dimensions(0)->set_base_dilation(-1);
  conv->set_window(w);

  // Assert failure first, consistent with the other negative tests in this
  // file, so an unexpectedly passing verifier yields a clear failure.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("non-positive base area dilation factor"));
}
375
// HLO text for an add whose operands use layouts {1,0} and {0,1} while the
// result uses {1,0}. Shared by a layout-insensitive and a layout-sensitive test.
constexpr char kAddWithLayoutChangeHlo[] = R"(
HloModule AddWithLayoutChange
ENTRY AddWithLayoutChange {
par0 = f32[3,4]{1,0} parameter(0)
par1 = f32[3,4]{0,1} parameter(1)
ROOT add0 = f32[3,4]{1,0} add(par0,par1)
}
)";
384
// A layout-insensitive verifier accepts an add whose operand layouts differ
// from the result layout.
TEST_F(HloVerifierTest, AddWithLayoutChange) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo));
  // TF_ASSERT_OK prints the verifier's error message on failure, whereas
  // ASSERT_TRUE(status.ok()) would only report "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
390
// With xla_allow_scalar_index_dynamic_ops enabled in the module config, a
// dynamic-slice taking one scalar start index per dimension must verify.
TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) {
  const char* const kScalarIndexDynamicSlice = R"(
HloModule DynamicSlice_module

ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
%original_parameter = s32[2,2,258] parameter(0)
%constant = s32[] constant(0)
%start_index = s32[] parameter(1)
ROOT %dynamic-slice = s32[2,2,258] dynamic-slice(s32[2,2,258] %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
}
)";

  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseHloString(kScalarIndexDynamicSlice, config));
  // TF_ASSERT_OK surfaces the verifier's error message on failure.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
413
// With xla_allow_scalar_index_dynamic_ops enabled in the module config, a
// dynamic-update-slice taking one scalar start index per dimension must verify.
TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) {
  // Named for the op it exercises; it was previously a copy-pasted
  // "kScalarIndexDynamicSlice".
  const char* const kScalarIndexDynamicUpdateSlice = R"(
HloModule DynamicUpdateSlice_module

ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
%input = s32[1,1,25,1]{3,2,1,0} parameter(0)
%update = s32[1,1,2,1]{3,2,1,0} parameter(1)
%start_index.0 = s32[] parameter(2)
%start_index.1 = s32[] parameter(3)
%start_index.2 = s32[] parameter(4)
%start_index.3 = s32[] parameter(5)
ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3)
}
)";

  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseHloString(kScalarIndexDynamicUpdateSlice, config));
  // TF_ASSERT_OK surfaces the verifier's error message on failure.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
439
// The layout-sensitive verifier must reject an add whose operand layouts differ
// from its result layout.
TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kAddWithLayoutChangeHlo));
  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
447
// The layout-sensitive verifier must reject a dynamic-slice that reads a {0,1}
// operand but produces a {1,0} result.
TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
  constexpr char kHlo[] = R"(
HloModule SliceWithLayoutChange
ENTRY SliceWithLayoutChange {
par0 = f32[4,5]{0,1} parameter(0)
par1 = s32[] parameter(1)
par2 = s32[] parameter(2)
ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2),
dynamic_slice_sizes={3,4}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
466
// The layout-sensitive verifier must reject a concatenate whose operands use
// mixed layouts ({0,1} and {1,0}) feeding a {1,0} result.
TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
  constexpr char kHlo[] = R"(
HloModule ConcatWithLayoutChange
ENTRY ConcatWithLayoutChange {
par0 = f32[3,5]{0,1} parameter(0)
par1 = f32[3,3]{1,0} parameter(1)
ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
dimensions={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
484
// A bitcast from f32[2] to s32[2] changes the element type and is rejected.
TEST_F(HloVerifierTest, BitcastCanNotChangeElementType) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY BitcastCanNotChangeElementType {
constant.0 = f32[2] constant({0.0, 0.0})
ROOT bitcast = s32[2] bitcast(constant.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Bitcast can not change the element type"));
}
501
// With mixed precision disallowed, a select over f32 and bf16 branches must be
// rejected.
TEST_F(HloVerifierTest, SelectMixedPrecisionNotAllowed) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY SelectMixedPrecisionNotAllowed {
p0 = pred[] parameter(0)
p1 = f32[32] parameter(1)
p2 = bf16[32] parameter(2)
ROOT select = f32[32] select(p0, p1, p2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Seen floating point types of different precisions"));
}
520
// With mixed precision allowed, a select over f32 and bf16 branches verifies.
TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY SelectMixedPrecisionAllowed {
p0 = pred[] parameter(0)
p1 = f32[32] parameter(1)
p2 = bf16[32] parameter(2)
ROOT select = f32[32] select(p0, p1, p2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));

  // TF_ASSERT_OK prints the verifier's error message on failure, whereas
  // ASSERT_TRUE(status.ok()) would only report "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
537
// An iota whose result shape is the empty tuple (not an array) is rejected.
TEST_F(HloVerifierTest, IotaNonArrayResult) {
  constexpr char kHlo[] = R"(
HloModule IotaTupleResult

ENTRY kernelEntry {
ROOT iota = () iota(), iota_dimension=24
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("does not support non-array result"));
}
554
// HLO text in which a map applies an f32 computation to an f64 operand.
// Verification outcome depends on whether mixed precision is allowed.
constexpr char kMapOperandComputationMismatchHlo[] = R"(
HloModule MapOperandComputationMismatch

Computation {
param0 = f32[] parameter(0)
constant = f32[] constant(1)
ROOT add = f32[] add(param0, constant)
}

ENTRY kernelEntry {
param = f64[] parameter(0)
ROOT map = f32[] map(param), dimensions={}, to_apply=Computation
})";
568
// With mixed precision disallowed, mapping an f32 computation over an f64
// operand must be rejected.
TEST_F(HloVerifierTest, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kMapOperandComputationMismatchHlo));
  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  const char* const expected =
      "Shape mismatch between to_apply computation parameter and operand";
  EXPECT_THAT(result.error_message(), HasSubstr(expected));
}
579
// With mixed precision allowed, mapping an f32 computation over an f64 operand
// verifies.
TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseHloString(kMapOperandComputationMismatchHlo));
  // TF_ASSERT_OK prints the verifier's error message on failure, whereas
  // ASSERT_TRUE(status.ok()) would only report "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
586
// HLO text in which an f16 reduce uses an f32 reduction computation.
// Verification outcome depends on whether mixed precision is allowed.
constexpr char kReduceOperandComputationMismatchHlo[] = R"(
HloModule ReduceOperandComputationMismatch
computation {
x = f32[] parameter(0)
y = f32[] parameter(1)
ROOT add = f32[] add(x, y)
}

ENTRY kernelEntry {
arg0 = f16[64,64,224,224]{3,2,1,0} parameter(0)
constant = f16[] constant(0)
reduce = f16[64]{0} reduce(arg0, constant), dimensions={0,2,3}, to_apply=computation
})";
600
// With mixed precision disallowed, an f16 reduce driven by an f32 computation
// is rejected because the inferred result shape is f32[64].
TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(kReduceOperandComputationMismatchHlo));
  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  const char* const expected =
      "Expected instruction to have shape equal to f32[64]";
  EXPECT_THAT(result.error_message(), HasSubstr(expected));
}
609
// With mixed precision allowed, an f16 reduce driven by an f32 computation
// verifies.
TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseHloString(kReduceOperandComputationMismatchHlo));
  // TF_ASSERT_OK prints the verifier's error message on failure, whereas
  // ASSERT_TRUE(status.ok()) would only report "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
616
617 } // namespace
618 } // namespace xla
619