1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/pattern_matcher.h"

#include <sstream>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/platform/test.h"
23
24 namespace xla {
25 namespace {
26
// Short alias for the `match` namespace; the DescribeTo/Explain tests later in
// this file spell their patterns as `m::...`.
namespace m = match;
28
// Matches the root `add` of a tiny module against a composite Op() pattern and
// verifies every capture slot, then checks the Add()/Multiply() helpers and a
// name mismatch. A failed match must leave its capture slot untouched.
TEST(PatternMatcherTest, AddOp) {
  constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
    ENTRY %two_plus_two_computation () -> f32[] {
      %two = f32[] constant(2)
      ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();

  // Every capture slot starts out null so that a failed match can never leave
  // us reading an indeterminate pointer, and so the checks below are real.
  const HloInstruction* matched_inst = nullptr;
  HloInstruction* matched_operand = nullptr;
  Shape* matched_shape = nullptr;
  Layout* matched_layout = nullptr;

  ASSERT_TRUE(Match(
      root,
      match::Op(&matched_inst)
          .WithName("two_plus_two")
          .WithOpcode(HloOpcode::kAdd)
          .WithShape(
              match::Shape(&matched_shape)
                  .WithLayout(match::Layout(&matched_layout).WithDenseFormat()))
          .WithOperand(
              0,
              match::Op(&matched_operand).WithOpcode(HloOpcode::kConstant))));
  ASSERT_NE(matched_inst, nullptr);
  EXPECT_EQ(matched_inst->name(), "two_plus_two");
  EXPECT_EQ(matched_inst->opcode(), HloOpcode::kAdd);
  // The nested sub-patterns must have captured too.
  EXPECT_EQ(matched_operand, root->operand(0));
  ASSERT_NE(matched_shape, nullptr);
  EXPECT_EQ(matched_layout, &matched_shape->layout());

  EXPECT_TRUE(Match(root, match::Add(match::Constant(), match::Constant())));

  EXPECT_FALSE(Match(root, match::Op().WithName("bad_name")));

  // A failed match must not write to its capture slot.
  matched_inst = nullptr;
  EXPECT_FALSE(
      Match(root, match::Multiply(&matched_inst, match::Op(), match::Op())));
  EXPECT_EQ(nullptr, matched_inst);
}
67
// A rank-0 F32 shape is a (dense) array and a scalar, is not a tuple, and has
// no subshapes.
TEST(PatternMatcherTest, ScalarShape) {
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});
  // Initialized so the EXPECT_EQ below never reads an indeterminate pointer
  // if the (non-fatal) match expectation fails.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar()));
  EXPECT_EQ(matched_shape, &scalar_shape);
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsArray()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&scalar_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithRank(0)));
  EXPECT_FALSE(Match(
      &scalar_shape,
      match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32)));
}
82
// A dense f32[2,3,4] shape is a dense (not sparse) rank-3 array with a dense
// layout, is neither a scalar nor a tuple, and has no subshapes.
TEST(PatternMatcherTest, DenseArrayShape) {
  auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  // Capture slots start null: the EXPECT_EQs below must never read an
  // indeterminate pointer if a non-fatal expectation fails.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
  EXPECT_EQ(matched_shape, &array_shape);
  EXPECT_TRUE(Match(&array_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsSparseArray()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3)));
  EXPECT_FALSE(
      Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
  Layout* matched_layout = nullptr;
  EXPECT_FALSE(Match(&array_shape,
                     match::Shape().WithLayout(
                         match::Layout(&matched_layout).WithSparseFormat())));
  // A failed match must not capture.
  EXPECT_EQ(nullptr, matched_layout);
  EXPECT_TRUE(Match(&array_shape,
                    match::Shape().WithLayout(
                        match::Layout(&matched_layout).WithDenseFormat())));
  EXPECT_EQ(matched_layout, &array_shape.layout());
}
105
// A sparse f32[2,3,4] shape is a sparse (not dense) rank-3 array with a sparse
// layout, is neither a scalar nor a tuple, and has no subshapes.
TEST(PatternMatcherTest, SparseArrayShape) {
  auto array_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {2, 3, 4}, 10);
  // Capture slots start null: the EXPECT_EQs below must never read an
  // indeterminate pointer if a non-fatal expectation fails.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
  EXPECT_EQ(matched_shape, &array_shape);
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsDenseArray()));
  EXPECT_TRUE(Match(&array_shape, match::Shape().IsSparseArray()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3)));
  EXPECT_FALSE(
      Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
  Layout* matched_layout = nullptr;
  EXPECT_FALSE(Match(&array_shape,
                     match::Shape().WithLayout(
                         match::Layout(&matched_layout).WithDenseFormat())));
  // A failed match must not capture.
  EXPECT_EQ(nullptr, matched_layout);
  EXPECT_TRUE(Match(&array_shape,
                    match::Shape().WithLayout(
                        match::Layout(&matched_layout).WithSparseFormat())));
  EXPECT_EQ(matched_layout, &array_shape.layout());
}
128
// A two-element tuple shape is a tuple (not an array or scalar). Its subshapes
// at {0} and {1} can be matched and captured individually, and out-of-range
// or too-deep indices never match.
TEST(PatternMatcherTest, TupleShape) {
  auto tuple_shape = ShapeUtil::MakeTupleShape({
      ShapeUtil::MakeShape(F32, {1, 2, 3}),
      ShapeUtil::MakeShape(S32, {4, 5}),
  });
  EXPECT_TRUE(Match(&tuple_shape, match::Shape().IsTuple()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsArray()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsScalar()));

  Shape* subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {0}, match::Shape(&subshape).WithElementType(F32).WithRank(3))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {0})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {0}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {0})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {0}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {1})))));

  // Reset the capture; otherwise the ASSERT_NE below would trivially pass on
  // the stale {0} subshape captured above.
  subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {1}, match::Shape(&subshape).WithElementType(S32).WithRank(2))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {1})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {1}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {1})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {1}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {0})))));

  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({2}, match::Shape())));
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape())));
}
176
// WithFusionKind() matches a fusion only when its kind is exactly the one
// requested; a non-fusion instruction (the parameter) never matches.
TEST(PatternMatcherTest, FusionKind) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    fused_computation {
      ROOT fp0 = f32[] parameter(0)
    }

    ENTRY while.v11 {
      p0 = f32[] parameter(0)
      ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));

  HloInstruction* fusion =
      parsed_module->entry_computation()->root_instruction();
  const HloInstruction* param = fusion->operand(0);
  const auto loop_kind = HloInstruction::FusionKind::kLoop;
  const auto input_kind = HloInstruction::FusionKind::kInput;

  // The root fusion is kLoop, so only that kind matches.
  EXPECT_TRUE(Match(fusion, m::Op().WithFusionKind(loop_kind)));
  EXPECT_FALSE(Match(fusion, m::Op().WithFusionKind(input_kind)));
  // The parameter is not a fusion at all.
  EXPECT_FALSE(Match(param, m::Op().WithFusionKind(loop_kind)));
}
199
// A get-tuple-element matches WithTupleIndex()/GetTupleElement() only for the
// index it actually reads (1 in this module).
TEST(PatternMatcherTest, GetTupleElement) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    ENTRY while.v11 {
      p0 = (f32[], f32[], f32[]) parameter(0)
      ROOT gte = f32[] get-tuple-element(p0), index=1
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  HloInstruction* gte = parsed_module->entry_computation()->root_instruction();

  // Generic Op() constrained by tuple index.
  EXPECT_FALSE(Match(gte, m::Op().WithTupleIndex(0)));
  EXPECT_TRUE(Match(gte, m::Op().WithTupleIndex(1)));
  EXPECT_FALSE(Match(gte, m::Op().WithTupleIndex(2)));

  // Dedicated GetTupleElement() helper.
  EXPECT_FALSE(Match(gte, m::GetTupleElement(m::Op(), 0)));
  EXPECT_TRUE(Match(gte, m::GetTupleElement(m::Op(), 1)));
}
217
// AnyOf() succeeds when at least one alternative matches, regardless of where
// that alternative appears in the list.
TEST(PatternMatcherTest, AnyOf) {
  using match::AnyOf;
  using match::ConstantScalar;

  constexpr char kModuleStr[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  HloInstruction* one = parsed_module->entry_computation()->root_instruction();

  // The matching alternative (value 1) listed second, then first.
  EXPECT_TRUE(
      Match(one, AnyOf<HloInstruction>(ConstantScalar(0), ConstantScalar(1))));
  EXPECT_TRUE(
      Match(one, AnyOf<HloInstruction>(ConstantScalar(1), ConstantScalar(0))));
  // No alternative matches.
  EXPECT_FALSE(
      Match(one, AnyOf<HloInstruction>(ConstantScalar(0), ConstantScalar(2))));
}
234
// ConstantScalar() accepts only rank-0 constants; ConstantEffectiveScalar()
// additionally accepts single-element constants such as s32[1,1]. Expected
// values may be given as integer or floating-point literals.
TEST(PatternMatcherTest, ConstantScalar) {
  using match::ConstantEffectiveScalar;
  using match::ConstantScalar;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      a = s32[] constant(1)
      b = s32[1,1] constant({{2}})
      c = s32[1,2] constant({{2,2}})
      d = f32[] constant(1)
      e = f32[] constant(1.25)
      ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  const HloInstruction* tuple =
      parsed_module->entry_computation()->root_instruction();

  const HloInstruction* s32_one = tuple->operand(0);       // s32[] 1
  const HloInstruction* s32_two_1x1 = tuple->operand(1);   // s32[1,1] {{2}}
  const HloInstruction* s32_twos_1x2 = tuple->operand(2);  // s32[1,2] {{2,2}}
  const HloInstruction* f32_one = tuple->operand(3);       // f32[] 1
  const HloInstruction* f32_one_quarter = tuple->operand(4);  // f32[] 1.25

  // True scalar: both matchers accept it, and values are compared exactly.
  EXPECT_TRUE(Match(s32_one, ConstantScalar()));
  EXPECT_TRUE(Match(s32_one, ConstantScalar(1)));
  EXPECT_TRUE(Match(s32_one, ConstantEffectiveScalar()));
  EXPECT_TRUE(Match(s32_one, ConstantEffectiveScalar(1)));
  EXPECT_FALSE(Match(s32_one, ConstantScalar(2)));
  EXPECT_FALSE(Match(s32_one, ConstantScalar(2.01)));
  EXPECT_FALSE(Match(s32_one, ConstantEffectiveScalar(2)));
  EXPECT_FALSE(Match(s32_one, ConstantEffectiveScalar(1.01)));

  // Single-element array: only the "effective" matcher accepts it.
  EXPECT_FALSE(Match(s32_two_1x1, ConstantScalar()));
  EXPECT_FALSE(Match(s32_two_1x1, ConstantScalar(2)));
  EXPECT_TRUE(Match(s32_two_1x1, ConstantEffectiveScalar()));
  EXPECT_TRUE(Match(s32_two_1x1, ConstantEffectiveScalar(2)));

  // Two-element array: neither matcher accepts it.
  EXPECT_FALSE(Match(s32_twos_1x2, ConstantScalar()));
  EXPECT_FALSE(Match(s32_twos_1x2, ConstantScalar(2)));
  EXPECT_FALSE(Match(s32_twos_1x2, ConstantEffectiveScalar()));
  EXPECT_FALSE(Match(s32_twos_1x2, ConstantEffectiveScalar(2)));

  // Floating-point scalar compared against int and double literals.
  EXPECT_TRUE(Match(f32_one, ConstantScalar(1)));
  EXPECT_TRUE(Match(f32_one, ConstantEffectiveScalar(1)));
  EXPECT_TRUE(Match(f32_one, ConstantScalar(1.0)));
  EXPECT_TRUE(Match(f32_one, ConstantEffectiveScalar(1.0)));

  // Non-integral value: matches 1.25 (float or double), not 1.
  EXPECT_TRUE(Match(f32_one_quarter, ConstantScalar(1.25f)));
  EXPECT_TRUE(Match(f32_one_quarter, ConstantScalar(1.25)));
  EXPECT_TRUE(Match(f32_one_quarter, ConstantEffectiveScalar(1.25)));
  EXPECT_FALSE(Match(f32_one_quarter, ConstantScalar(1)));
  EXPECT_FALSE(Match(f32_one_quarter, ConstantEffectiveScalar(1)));

  // Every capturing overload records the matched instruction.
  {
    const HloInstruction* captured = nullptr;
    EXPECT_TRUE(Match(s32_one, ConstantScalar(&captured)));
    EXPECT_EQ(captured, s32_one);
  }
  {
    const HloInstruction* captured = nullptr;
    EXPECT_TRUE(Match(s32_one, ConstantScalar(&captured, 1)));
    EXPECT_EQ(captured, s32_one);
  }
  {
    const HloInstruction* captured = nullptr;
    EXPECT_TRUE(Match(s32_one, ConstantEffectiveScalar(&captured)));
    EXPECT_EQ(captured, s32_one);
  }
  {
    const HloInstruction* captured = nullptr;
    EXPECT_TRUE(Match(s32_one, ConstantEffectiveScalar(&captured, 1)));
    EXPECT_EQ(captured, s32_one);
  }
}
305
// MultiplyAnyOrder() accepts its operand patterns in either order, captures
// the multiply itself, and returns a pattern that supports the same builder
// methods as Op().
TEST(PatternMatcherTest, MultiplyAnyOrder) {
  using match::ConstantScalar;
  using match::MultiplyAnyOrder;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();
  // Initialized and reset between matches, so each capture is checked
  // independently instead of the slot being left indeterminate and unchecked.
  const HloInstruction* instr = nullptr;

  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
  EXPECT_EQ(instr, root);

  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
  EXPECT_EQ(instr, root);

  // Check that MultiplyAnyOrder exposes the same API as Op(), so we can call
  // e.g. IsNonConstant() on it.
  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))
                .IsNonConstant()));
  EXPECT_EQ(instr, root);
  EXPECT_TRUE(
      Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52))
                      .IsNonConstant()));
}
335
// AnyOf() stops at the first alternative that matches, so only that
// alternative's captures are filled in.
TEST(PatternMatcherTest, AnyOfShortCircuit) {
  using match::AnyOf;
  using match::Multiply;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  HloInstruction* multiply =
      parsed_module->entry_computation()->root_instruction();

  // Multiply() listed first: it matches, so the catch-all Op() never runs.
  {
    const HloInstruction* mul_hit = nullptr;
    const HloInstruction* any_hit = nullptr;
    ASSERT_TRUE(Match(multiply,
                      AnyOf<HloInstruction>(Multiply(&mul_hit, Op(), Op()),
                                            Op(&any_hit))));
    EXPECT_NE(nullptr, mul_hit);
    EXPECT_EQ(nullptr, any_hit);
  }

  // Catch-all Op() listed first: it matches, so Multiply() never runs.
  {
    const HloInstruction* mul_hit = nullptr;
    const HloInstruction* any_hit = nullptr;
    ASSERT_TRUE(Match(multiply,
                      AnyOf<HloInstruction>(Op(&any_hit),
                                            Multiply(&mul_hit, Op(), Op()))));
    EXPECT_NE(nullptr, any_hit);
    EXPECT_EQ(nullptr, mul_hit);
  }
}
370
// AllOf() matches only when every sub-pattern matches; listing order does not
// matter, and a single failing sub-pattern makes the whole AllOf() fail.
TEST(PatternMatcherTest, AllOf) {
  using match::AllOf;
  using match::Broadcast;
  using match::Constant;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  HloInstruction* root_constant =
      parsed_module->entry_computation()->root_instruction();

  auto scalar_f16 = ShapeUtil::MakeShape(F16, {});
  auto equals_f16 = Constant().WithShapeEqualTo(&scalar_f16);
  auto compatible_f16 = Constant().WithShapeCompatibleTo(&scalar_f16);
  auto is_scalar = Constant().WithShape(match::Shape().IsScalar());

  // Each sub-pattern matches on its own...
  ASSERT_TRUE(Match(root_constant, is_scalar));
  ASSERT_TRUE(Match(root_constant, equals_f16));
  ASSERT_TRUE(Match(root_constant, compatible_f16));

  // ...so their conjunction matches in any order.
  EXPECT_TRUE(Match(root_constant, AllOf<HloInstruction>(
                                       is_scalar, equals_f16, compatible_f16)));
  EXPECT_TRUE(Match(root_constant, AllOf<HloInstruction>(
                                       equals_f16, compatible_f16, is_scalar)));

  // The root is not a broadcast, which sinks every conjunction below.
  EXPECT_FALSE(Match(root_constant,
                     AllOf<HloInstruction>(Broadcast(Op()), equals_f16)));
  EXPECT_FALSE(Match(root_constant,
                     AllOf<HloInstruction>(Broadcast(Op()), compatible_f16)));
  EXPECT_FALSE(Match(root_constant,
                     AllOf<HloInstruction>(Broadcast(Op()), is_scalar)));
}
401
// If AllOf() fails overall, sub-patterns that did match must not leave
// captures behind; a later successful match does capture.
TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
  using match::AllOf;
  using match::Broadcast;
  using match::Constant;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  HloInstruction* root = parsed_module->entry_computation()->root_instruction();

  const HloInstruction* captured_constant = nullptr;
  // The root is a constant but not a broadcast, so this conjunction fails.
  const auto constant_and_broadcast = AllOf<HloInstruction>(
      Constant(&captured_constant), Broadcast(Op()));

  ASSERT_FALSE(Match(root, constant_and_broadcast));
  EXPECT_EQ(nullptr, captured_constant);

  ASSERT_TRUE(Match(root, Constant(&captured_constant)));
  EXPECT_NE(nullptr, captured_constant);
}
423
// With capture disabled in MatchOption, even a successful match leaves the
// capture slot untouched.
TEST(PatternMatcherTest, TestNoCapture) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  HloInstruction* root = parsed_module->entry_computation()->root_instruction();

  const HloInstruction* captured = nullptr;
  const MatchOption without_capture{/*capture=*/false};
  ASSERT_TRUE(Match(root, match::Constant(&captured), without_capture));
  EXPECT_EQ(nullptr, captured);
}
439
// Only captures reached by the AnyOf() alternative that actually matched are
// kept. Here the three-term sum fails (neither operand is an add), the
// two-term sum matches, so the third addend stays null.
TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
  using match::Add;
  using match::AddAnyOrder;
  using match::AnyOf;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      u = f16[] parameter(0)
      v = f16[] parameter(1)
      ROOT add = f16[] add(u, v)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  HloInstruction* root = parsed_module->entry_computation()->root_instruction();

  const HloInstruction* first_addend = nullptr;
  const HloInstruction* second_addend = nullptr;
  const HloInstruction* third_addend = nullptr;

  auto two_term_sum = Add(Op(&first_addend), Op(&second_addend));
  auto two_or_three_term_sum = AnyOf<HloInstruction>(
      AddAnyOrder(two_term_sum, Op(&third_addend)), two_term_sum,
      Op(&first_addend));

  ASSERT_TRUE(Match(root, two_or_three_term_sum));
  EXPECT_NE(nullptr, first_addend);
  EXPECT_NE(nullptr, second_addend);
  EXPECT_EQ(nullptr, third_addend);
}
468
// Concatenate(...) requires its operand patterns to line up one-for-one with
// the instruction's operands: same order and same count.
TEST(PatternMatcherTest, TestConcat) {
  using match::Concatenate;
  using match::ConstantScalar;
  using match::Reshape;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      c1 = u32[] constant(1)
      c2 = u32[] constant(2)
      c3 = u32[] constant(3)
      c4 = u32[] constant(4)
      r1 = u32[1] reshape(c1)
      r2 = u32[1] reshape(c2)
      r3 = u32[1] reshape(c3)
      r4 = u32[1] reshape(c4)
      ROOT concat = u32[4] concatenate(r1, r2, r3, r4), dimensions={0}
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  HloInstruction* concat =
      parsed_module->entry_computation()->root_instruction();

  // Exact operand sequence 1, 2, 3, 4.
  ASSERT_TRUE(Match(
      concat,
      Concatenate(Reshape(ConstantScalar(1)), Reshape(ConstantScalar(2)),
                  Reshape(ConstantScalar(3)), Reshape(ConstantScalar(4)))));
  // First two operands swapped.
  ASSERT_FALSE(Match(
      concat,
      Concatenate(Reshape(ConstantScalar(2)), Reshape(ConstantScalar(1)),
                  Reshape(ConstantScalar(3)), Reshape(ConstantScalar(4)))));
  // Too few operand patterns: a prefix, then a suffix, of the real operands.
  ASSERT_FALSE(Match(
      concat, Concatenate(Reshape(ConstantScalar(1)), Reshape(ConstantScalar(2)),
                          Reshape(ConstantScalar(3)))));
  ASSERT_FALSE(Match(
      concat, Concatenate(Reshape(ConstantScalar(2)), Reshape(ConstantScalar(3)),
                          Reshape(ConstantScalar(4)))));
}
505
// Returns the human-readable text that `pattern` writes via DescribeTo().
template <typename Pattern>
string Description(const Pattern& pattern) {
  // The stream is only written to, so an output-only stream is sufficient.
  std::ostringstream ss;
  pattern.DescribeTo(&ss);
  return ss.str();
}
512
513 template <typename Elem, typename Pattern>
Explanation(Elem * elem,const Pattern & pattern)514 string Explanation(Elem* elem, const Pattern& pattern) {
515 std::stringstream ss;
516 MatchOption options{/*.capture=*/true, /*.explain_os=*/&ss};
517 Match(elem, pattern, options);
518 return ss.str();
519 }
520 template <typename Elem, typename Pattern>
Explanation(const std::unique_ptr<Elem> & elem,const Pattern & pattern)521 string Explanation(const std::unique_ptr<Elem>& elem, const Pattern& pattern) {
522 return Explanation(elem.get(), pattern);
523 }
524 template <typename Elem, typename Pattern>
Explanation(const Elem & elem,const Pattern & pattern)525 string Explanation(const Elem& elem, const Pattern& pattern) {
526 return Explanation(&elem, pattern);
527 }
528
// Helper macro for checking a pattern's description and the explanation printed
// when attempting to match (and presumably failing) on a given object.
//
// We use a macro rather than a function because we want good line numbers in
// errors. We use this rather than writing a helper that returns a pair of
// (description, explanation) and doing something like
//
// EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...));
//
// because EXPECT_EQ prints a unified diff if multiline string comparison fails,
// while EXPECT_THAT does not. This unified diff makes the errors much easier
// to read.
//
// Every macro argument is parenthesized at each use so that arbitrary
// expressions expand safely.
#define EXPECT_DESC_AND_EXPLANATION(elem, pattern, expected_desc,        \
                                    expected_explanation)                \
  do {                                                                   \
    EXPECT_EQ(Description((pattern)), (expected_desc));                  \
    EXPECT_EQ(Explanation((elem), (pattern)), (expected_explanation));   \
  } while (0)
547
// DescribeTo() text and failure explanations for Layout patterns, including a
// null layout and a two-condition pattern.
TEST(PatternMatcherTest, LayoutDescribeToAndExplain) {
  const Layout expected_layout = LayoutUtil::MakeLayout({1, 2});
  const Layout other_layout = LayoutUtil::MakeLayout({2, 2});

  EXPECT_DESC_AND_EXPLANATION(static_cast<const Layout*>(nullptr), m::Layout(),
                              "a layout", "Layout is null");
  EXPECT_DESC_AND_EXPLANATION(other_layout,
                              m::Layout().EqualTo(&expected_layout),
                              "a layout equal to {1,2}",
                              "Layout {2,2} is not equal to expected {1,2}");
  EXPECT_DESC_AND_EXPLANATION(other_layout, m::Layout().WithSparseFormat(),
                              "a layout with format SPARSE",
                              "Layout has format DENSE but expected SPARSE");
  // With two conditions the description becomes a list, but only the failing
  // condition appears in the explanation.
  EXPECT_DESC_AND_EXPLANATION(
      expected_layout,
      m::Layout().EqualTo(&expected_layout).WithSparseFormat(),
      "a layout:\n"
      " * equal to {1,2} AND\n"
      " * with format SPARSE",
      "Layout has format DENSE but expected SPARSE");
}
567
// Checks the DescribeTo() text and the failure explanation for each Shape
// pattern predicate. An explanation names the failing shape on an "in ..."
// line, and each enclosing level (subshape index, outer tuple) adds another
// "in ..." line after it.
TEST(PatternMatcherTest, ShapeDescribeToAndExplain) {
  // Deliberately uses the non-default {0,1} layout, so EqualTo() can be shown
  // to reject the same shape with layout {1,0}.
  auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
  // Copy of shape's layout, used as the expected value of WithLayoutEqualTo().
  auto layout = shape.layout();

  EXPECT_DESC_AND_EXPLANATION(static_cast<const Shape*>(nullptr), m::Shape(),
                              "a shape", "Shape is null");
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
      m::Shape().EqualTo(&shape), "a shape equal to f32[1,2]{0,1}",
      "Shape not equal to f32[1,2]{0,1}\n"
      "in f32[1,2]{1,0}");
  // For CompatibleTo() the expected shape is printed without its layout.
  EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeShape(F32, {2, 2}),
                              m::Shape().CompatibleTo(&shape),
                              "a shape compatible with f32[1,2]",
                              "Shape not compatible with f32[1,2]\n"
                              "in f32[2,2]{1,0}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithElementType(F16),
                              "a shape with element type F16",
                              "Shape does not have element type F16\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsScalar(),
                              "a shape that represents a scalar",
                              "Shape is not a scalar\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), m::Shape().IsArray(),
                              "a shape that represents an array",
                              "Shape is not an array\n"
                              "in ()");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsTuple(),
                              "a shape that represents a tuple",
                              "Shape is not a tuple\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsEffectiveScalar(),
                              "a shape that is an effective scalar",
                              "Shape is not an effective scalar\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(42),
                              "a shape that has 42 dimensions",
                              "Shape does not have rank 42\n"
                              "in f32[1,2]{0,1}");
  // WithRank(0) is described and explained in terms of "scalar".
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(0),
                              "a shape that is a scalar",
                              "Shape is not a scalar\n"
                              "in f32[1,2]{0,1}");
  // Multi-condition patterns list every condition but explain only the first
  // one that fails.
  EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(1).IsArray(),
                              "a shape:\n"
                              " * that has 1 dimension AND\n"
                              " * that represents an array",
                              "Shape does not have rank 1\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(),
                              m::Shape().IsArray().WithRank(1),
                              "a shape:\n"
                              " * that represents an array AND\n"
                              " * that has 1 dimension",
                              "Shape is not an array\n"
                              "in ()");
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
      m::Shape().WithLayoutEqualTo(&layout),
      "a shape with\n a layout equal to {0,1}",
      "Layout {1,0} is not equal to expected {0,1}\n"
      "in f32[1,2]{1,0}");
  EXPECT_DESC_AND_EXPLANATION(
      shape, m::Shape().WithLayout(m::Layout().WithSparseFormat()),
      "a shape with\n a layout with format SPARSE",
      "Layout has format DENSE but expected SPARSE\n"
      "in f32[1,2]{0,1}");
  // Subshape predicates: a missing index vs. a present subshape that fails
  // the inner pattern (the latter adds "in subshape at ..." context).
  EXPECT_DESC_AND_EXPLANATION(shape,
                              m::Shape().WithSubshapeEqualTo({10}, &shape),
                              "a shape with subshape at index {10} which is\n"
                              " a shape equal to f32[1,2]{0,1}",
                              "No subshape at {10}\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
      m::Shape().WithSubshapeEqualTo({0}, &shape),
      "a shape with subshape at index {0} which is\n"
      " a shape equal to f32[1,2]{0,1}",
      "Shape not equal to f32[1,2]{0,1}\n"
      "in f32[2,2]{1,0}\n"
      "in subshape at {0}\n"
      "in (f32[2,2])");
  EXPECT_DESC_AND_EXPLANATION(shape,
                              m::Shape().WithSubshapeCompatibleTo({10}, &shape),
                              "a shape with subshape at index {10} which is\n"
                              " a shape compatible with f32[1,2]",
                              "No subshape at {10}\n"
                              "in f32[1,2]{0,1}");
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
      m::Shape().WithSubshapeCompatibleTo({0}, &shape),
      "a shape with subshape at index {0} which is\n"
      " a shape compatible with f32[1,2]",
      "Shape not compatible with f32[1,2]\n"
      "in f32[2,2]{1,0}\n"
      "in subshape at {0}\n"
      "in (f32[2,2])");
  // Nested tuple: the multi-level index {0,0} is reported as a single level.
  EXPECT_DESC_AND_EXPLANATION(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({shape})}),
      m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()),
      "a shape with subshape at index {0,0} which is\n"
      " a shape that represents a scalar",
      "Shape is not a scalar\n"
      "in f32[1,2]{0,1}\n"
      "in subshape at {0,0}\n"
      "in ((f32[1,2]))");
}
676
SetName(absl::string_view name,std::unique_ptr<HloInstruction> instr)677 std::unique_ptr<HloInstruction> SetName(absl::string_view name,
678 std::unique_ptr<HloInstruction> instr) {
679 instr->SetAndSanitizeName(string(name));
680 return instr;
681 }
682
// Checks the DescribeTo() text and failure explanations for HloInstruction
// pattern predicates. The instructions are created standalone with
// HloInstruction::Create*() and renamed via SetName(), so the names in the
// expected strings ("i", "c", "a") are fixed.
TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
  std::unique_ptr<HloInstruction> iota =
      SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}),
                                              /*iota_dimension=*/0));
  std::unique_ptr<HloInstruction> constant =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  EXPECT_DESC_AND_EXPLANATION(static_cast<const HloInstruction*>(nullptr),
                              m::Op(), "an HloInstruction",
                              "HloInstruction* is null");
  EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithName("foo"),
                              "an HloInstruction named \"foo\"",
                              "HloInstruction not named \"foo\"\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithOpcode(HloOpcode::kAdd),
                              "an HloInstruction with opcode add",
                              "HloInstruction doesn't have opcode add\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  EXPECT_DESC_AND_EXPLANATION(
      constant, m::Op().IsNonConstant(),
      "an HloInstruction with any opcode other than constant",
      "HloInstruction has opcode constant, expected anything else\n"
      "in c = s32[] constant(0)");
  EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithNumOperands(42),
                              "an HloInstruction with 42 operands",
                              "HloInstruction doesn't have 42 operands\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  // A failing shape sub-pattern is reported with "in output shape" context.
  EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(m::Shape().IsTuple()),
                              "an HloInstruction outputting\n"
                              " a shape that represents a tuple",
                              "Shape is not a tuple\n"
                              "in s32[42]{0}\n"
                              "in output shape\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  // iota has no operands, so operand index 2 is out of bounds.
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)),
      "an HloInstruction with operand 2 which is:\n"
      " an HloInstruction with opcode add",
      "desired operand index 2 is out of bounds\n"
      "in i = s32[42]{0} iota(), iota_dimension=0");

  // A temporary add whose two operands are both `constant`; the failing
  // operand is reported with "in operand 1" context.
  EXPECT_DESC_AND_EXPLANATION(
      SetName("a", HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}),
                                                HloOpcode::kAdd, constant.get(),
                                                constant.get())),
      m::Op().WithOperand(1, m::Op().IsNonConstant()),
      "an HloInstruction with operand 1 which is:\n"
      " an HloInstruction with any opcode other than constant",
      "HloInstruction has opcode constant, expected anything else\n"
      "in c = s32[] constant(0)\n"
      "in operand 1\n"
      "in a = s32[] add(s32[] c, s32[] c)");
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop),
      "an HloInstruction with fusion kind kLoop",
      "HloInstruction does not have fusion kind kLoop; it's not a fusion\n"
      "in i = s32[42]{0} iota(), iota_dimension=0");
  EXPECT_DESC_AND_EXPLANATION(
      iota, m::Op().WithTupleIndex(42),
      "an HloInstruction which is a GTE with index 42",
      "HloInstruction is not a GTE with index 42; it's not a GTE at all\n"
      "in i = s32[42]{0} iota(), iota_dimension=0");
  EXPECT_DESC_AND_EXPLANATION(iota, m::Op().IsConstantScalar(),
                              "an HloInstruction which is a constant scalar",
                              "HloInstruction is not a constant\n"
                              "in i = s32[42]{0} iota(), iota_dimension=0");
  // A two-element constant: it is a constant, but not an effective scalar.
  EXPECT_DESC_AND_EXPLANATION(
      SetName("c", HloInstruction::CreateConstant(
                       LiteralUtil::CreateR1<int>({1, 2}))),
      m::Op().IsConstantEffectiveScalar(),
      "an HloInstruction which is a constant effective scalar",
      "HloInstruction is not an effective scalar\n"
      "in c = s32[2]{0} constant({1, 2})");
  EXPECT_DESC_AND_EXPLANATION(
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))),
      m::Op().IsConstantScalar(42),
      "an HloInstruction which is a constant scalar with value 42",
      "HloInstruction's constant value 10 did not match expected value 42\n"
      "in c = s32[] constant(10)");
  EXPECT_DESC_AND_EXPLANATION(
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25))),
      m::Op().IsConstantEffectiveScalar(1.25),
      "an HloInstruction which is a constant effective scalar with value 1.25",
      "HloInstruction's constant value 2.25 did not match expected value 1.25\n"
      "in c = f64[] constant(2.25)");
  // Is() prints raw pointer values, so the expected strings are built at
  // runtime with absl::Hex.
  EXPECT_DESC_AND_EXPLANATION(
      constant, m::Op().Is(iota.get()),
      absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()),
                   " (i = s32[42]{0} iota(), iota_dimension=0)"),
      absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x",
                   absl::Hex(iota.get()),
                   " (i = s32[42]{0} iota(), iota_dimension=0)\n"
                   "in c = s32[] constant(0)"));
}
777
// Checks the description and failure explanations of AddAnyOrder(). The
// explanation reports each operand pattern's failure against both LHS and RHS.
//
// NOTE: the operand instructions below are unique_ptr temporaries (only their
// raw pointers are passed to CreateBinary), so they and the add "a" live only
// until the end of the full-expression in which the macro evaluates `elem`.
TEST(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
  // Operand "b" satisfies the first matcher, but neither operand satisfies the
  // second ("bar").
  EXPECT_DESC_AND_EXPLANATION(
      SetName("a", HloInstruction::CreateBinary(
                       scalar_s32, HloOpcode::kAdd,
                       SetName("b", HloInstruction::CreateConstant(
                                        LiteralUtil::CreateR0(0)))
                           .get(),
                       SetName("c", HloInstruction::CreateConstant(
                                        LiteralUtil::CreateR0(0)))
                           .get())),
      m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")),
      "an HloInstruction:\n"
      " * with opcode add AND\n"
      " * with two operands in either order:\n"
      " - an HloInstruction named \"b\"\n"
      " - an HloInstruction named \"bar\"",
      "HloInstruction's operands (ignoring order) did not match second "
      "matcher. Specifically,\n"
      " - an HloInstruction named \"bar\"\n"
      "does not match LHS:\n"
      " - HloInstruction not named \"bar\"\n"
      " in b = s32[] constant(0)\n"
      "does not match RHS:\n"
      " - HloInstruction not named \"bar\"\n"
      " in c = s32[] constant(0)\n"
      "in a = s32[] add(s32[] b, s32[] c)");

  // Here the LHS operand (a parameter) satisfies neither matcher, and the
  // explanation reports both of its failures.
  EXPECT_DESC_AND_EXPLANATION(
      SetName("a",
              HloInstruction::CreateBinary(
                  scalar_s32, HloOpcode::kAdd,
                  HloInstruction::CreateParameter(0, scalar_s32, "p").get(),
                  SetName("c", HloInstruction::CreateConstant(
                                   LiteralUtil::CreateR0(0)))
                      .get())),
      m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()),
      "an HloInstruction:\n"
      " * with opcode add AND\n"
      " * with two operands in either order:\n"
      " - an HloInstruction which is a constant scalar\n"
      " - an HloInstruction with opcode constant",
      "HloInstruction's LHS operand did not match either of the two matchers. "
      "Specifically,\n"
      " - an HloInstruction which is a constant scalar\n"
      "does not match LHS:\n"
      " - HloInstruction is not a constant\n"
      " in p = s32[] parameter(0)\n"
      "and\n"
      " - an HloInstruction with opcode constant\n"
      "does not match LHS:\n"
      " - HloInstruction doesn't have opcode constant\n"
      " in p = s32[] parameter(0)\n"
      "in a = s32[] add(s32[] p, s32[] c)");
}
833
// An AnyOf matcher describes its alternatives joined by "OR"; when none of
// them match, the explanation lists every alternative's failure, numbered in
// the order the alternatives were given.
TEST(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
  auto constant =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
  EXPECT_DESC_AND_EXPLANATION(
      constant,
      m::AnyOf<HloInstruction>(m::Op().WithName("foo"),
                               m::Op().WithName("bar")),
      "any of:\n"
      " - an HloInstruction named \"foo\" OR\n"
      " - an HloInstruction named \"bar\"",
      "None of the following matchers succeeded:\n"
      "Matcher #1\n"
      " - an HloInstruction named \"foo\"\n"
      "failed with\n"
      " - HloInstruction not named \"foo\"\n"
      " in c = s32[] constant(0)\n"
      "Matcher #2\n"
      " - an HloInstruction named \"bar\"\n"
      "failed with\n"
      " - HloInstruction not named \"bar\"\n"
      " in c = s32[] constant(0)");
}
854
// m::Parameter() matches any parameter instruction; m::Parameter(n)
// additionally requires the parameter number to be n.
TEST(PatternMatcherTest, Parameter) {
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  auto param = HloInstruction::CreateParameter(1, scalar_f32, "p1");
  auto non_param =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  // Parameter 1 matches the unnumbered pattern and its own number only.
  EXPECT_TRUE(Match(param.get(), m::Parameter()));
  EXPECT_TRUE(Match(param.get(), m::Parameter(1)));
  EXPECT_FALSE(Match(param.get(), m::Parameter(0)));

  // A constant is rejected whether or not a number is requested.
  EXPECT_FALSE(Match(non_param.get(), m::Parameter()));
  EXPECT_FALSE(Match(non_param.get(), m::Parameter(1)));

  // For a non-parameter, the explanation reports only the opcode mismatch.
  EXPECT_DESC_AND_EXPLANATION(non_param, m::Parameter(1),
                              "an HloInstruction:\n"
                              " * with opcode parameter AND\n"
                              " * which is parameter 1",
                              "HloInstruction doesn't have opcode parameter\n"
                              "in c = s32[] constant(0)");

  // For a parameter with the wrong number, it reports the number mismatch.
  auto param0 = HloInstruction::CreateParameter(0, scalar_f32, "p0");
  EXPECT_EQ(Explanation(param0.get(), m::Parameter(1)),
            "HloInstruction is not parameter 1\n"
            "in p0 = f32[] parameter(0)");
}
878
// WithOneUse() requires that exactly one user consumes the instruction exactly
// once; WithOneUser() only requires exactly one distinct user. The statement
// order below matters: users are added as instructions are created, and the
// later expectations depend on the scoped reshapes no longer being users.
TEST(PatternMatcherTest, OneUseAndOneUser) {
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  const Shape vector_f32 = ShapeUtil::MakeShape(F32, {1});
  auto p0 = HloInstruction::CreateParameter(0, scalar_f32, "p0");

  // With no users at all, both matchers fail with the same explanation.
  const char* kNoUserExplanation =
      "HloInstruction has 0 users, but expected exactly one.\n"
      "in p0 = f32[] parameter(0)";

  EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUse()));
  EXPECT_DESC_AND_EXPLANATION(p0, m::Op().WithOneUse(),
                              "an HloInstruction which has exactly one use",
                              kNoUserExplanation);

  EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUser()));
  EXPECT_DESC_AND_EXPLANATION(
      p0, m::Op().WithOneUser(),
      "an HloInstruction which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      kNoUserExplanation);

  {
    // One reshape: exactly one use by exactly one user.
    auto first_user =
        SetName("r", HloInstruction::CreateReshape(vector_f32, p0.get()));
    EXPECT_TRUE(Match(p0.get(), m::Op().WithOneUse()));
    EXPECT_TRUE(Match(p0.get(), m::Op().WithOneUser()));

    // A second reshape: two distinct users, so both matchers fail and list
    // every user in the explanation.
    auto second_user =
        SetName("r1", HloInstruction::CreateReshape(vector_f32, p0.get()));
    EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUse()));
    EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUser()));

    const char* kTwoUserExplanation =
        "HloInstruction has 2 users, but expected exactly one.\n"
        "All users:\n"
        " - r = f32[1]{0} reshape(f32[] p0)\n"
        " - r1 = f32[1]{0} reshape(f32[] p0)\n"
        "in p0 = f32[] parameter(0)";
    EXPECT_EQ(Explanation(p0.get(), m::Op().WithOneUse()),
              kTwoUserExplanation);
    EXPECT_EQ(Explanation(p0.get(), m::Op().WithOneUser()),
              kTwoUserExplanation);
  }

  // Both reshapes are gone. A single add that consumes p0 twice is one user
  // (WithOneUser succeeds) but two uses (WithOneUse fails).
  auto twice_used = SetName(
      "add", HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd,
                                          p0.get(), p0.get()));
  EXPECT_TRUE(Match(p0.get(), m::Op().WithOneUser()));
  EXPECT_FALSE(Match(p0.get(), m::Op().WithOneUse()));
  EXPECT_EQ(Explanation(p0.get(), m::Op().WithOneUse()),
            "HloInstruction is used 2 times by its user, but is expected to be "
            "used just once: add = f32[] add(f32[] p0, f32[] p0)\n"
            "in p0 = f32[] parameter(0)");
}
933
// Compare matchers. m::Compare() accepts any comparison. m::Eq(), m::Ne() and
// m::Le() additionally check the comparison direction. Operand sub-patterns
// are positional unless an *AnyOrder variant is used.
//
// Declared in the PatternMatcherTest suite, consistent with every other test
// in this file; it was previously declared under HloMatchersTest.
TEST(PatternMatcherTest, Comparison) {
  auto shape = ShapeUtil::MakeShape(F32, {1});
  auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
  auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
  // NOTE(review): the compares are built with the F32 operand shape as their
  // result shape (real compares normally produce PRED). This is kept because
  // the expected explanation text below spells out "f32[1]{0} compare".
  auto eq = HloInstruction::CreateCompare(shape, p0.get(), p1.get(),
                                          ComparisonDirection::kEq);
  auto ne = HloInstruction::CreateCompare(shape, p0.get(), p1.get(),
                                          ComparisonDirection::kNe);
  auto add =
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get());
  auto le = HloInstruction::CreateCompare(shape, p0.get(), add.get(),
                                          ComparisonDirection::kLe);

  // Positive matches: generic compare, matching direction, positional
  // operands, and swapped operands through the AnyOrder variant.
  EXPECT_TRUE(Match(eq.get(), m::Compare()));
  EXPECT_TRUE(Match(eq.get(), m::Eq()));
  EXPECT_TRUE(Match(eq.get(), m::Eq(m::Parameter(0), m::Parameter(1))));
  EXPECT_TRUE(Match(eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0))));
  EXPECT_TRUE(Match(ne.get(), m::Compare()));
  EXPECT_TRUE(Match(ne.get(), m::Ne()));
  EXPECT_TRUE(Match(
      le.get(),
      m::Compare(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_TRUE(Match(le.get(), m::Le(m::Parameter(0),
                                    m::Add(m::Parameter(0), m::Parameter(1)))));

  // Negative matches: wrong opcode, wrong direction (checked both ways), and
  // swapped operands without AnyOrder.
  EXPECT_FALSE(Match(eq.get(), m::Add()));
  EXPECT_FALSE(Match(eq.get(), m::Ne()));
  EXPECT_FALSE(Match(ne.get(), m::Eq()));
  EXPECT_FALSE(
      Match(le.get(),
            m::Eq(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_FALSE(Match(eq.get(), m::Eq(m::Parameter(1), m::Parameter(0))));

  // A direction mismatch is reported ahead of the user-count requirement.
  EXPECT_DESC_AND_EXPLANATION(
      eq, m::Ne().WithOneUser(),
      "an HloInstruction:\n"
      " * with opcode compare AND\n"
      " * which has comparison direction NE AND\n"
      " * which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      "HloInstruction is not comparison NE\n"
      "in compare = f32[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), "
      "direction=EQ");
}
976
977 } // namespace
978 } // namespace xla
979