1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/pattern_matcher.h"

#include <memory>
#include <sstream>
#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/platform/test.h"
23 
24 namespace xla {
25 namespace {
26 
27 namespace m = match;
28 
// Matches the root `two + two` add against a chain of instruction, shape,
// layout and operand sub-patterns, then checks both the match results and the
// values written through every capture pointer.
TEST(PatternMatcherTest, AddOp) {
  constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
    ENTRY %two_plus_two_computation () -> f32[] {
      %two = f32[] constant(2)
      ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
  HloInstruction* root = hlo_module->entry_computation()->root_instruction();

  // Every capture slot starts at nullptr so the checks after the match really
  // prove the matcher wrote to it. Left uninitialized, the ASSERT_NE below
  // would read an indeterminate pointer and could pass by accident.
  const HloInstruction* matched_inst = nullptr;
  HloInstruction* matched_operand = nullptr;
  Shape* matched_shape = nullptr;
  Layout* matched_layout = nullptr;

  ASSERT_TRUE(Match(
      root,
      match::Op(&matched_inst)
          .WithName("two_plus_two")
          .WithOpcode(HloOpcode::kAdd)
          .WithShape(
              match::Shape(&matched_shape)
                  .WithLayout(match::Layout(&matched_layout).WithDenseFormat()))
          .WithOperand(
              0,
              match::Op(&matched_operand).WithOpcode(HloOpcode::kConstant))));
  ASSERT_NE(matched_inst, nullptr);
  EXPECT_EQ(matched_inst->name(), "two_plus_two");
  EXPECT_EQ(matched_inst->opcode(), HloOpcode::kAdd);
  // Captures made by nested sub-patterns must be populated too.
  EXPECT_EQ(matched_operand, root->operand(0));
  EXPECT_NE(matched_shape, nullptr);
  EXPECT_NE(matched_layout, nullptr);

  EXPECT_TRUE(Match(root, match::Add(match::Constant(), match::Constant())));

  EXPECT_FALSE(Match(root, match::Op().WithName("bad_name")));
  matched_inst = nullptr;
  EXPECT_FALSE(
      Match(root, match::Multiply(&matched_inst, match::Op(), match::Op())));
  // A failed match must not write through its capture pointer.
  EXPECT_EQ(matched_inst, nullptr);
}
67 
// Shape predicates on an f32 scalar: it is a (dense) array of rank 0 and not a
// tuple, and it has no subshape at index {0}.
TEST(PatternMatcherTest, ScalarShape) {
  auto scalar_shape = ShapeUtil::MakeShape(F32, {});
  // Initialized so the EXPECT_EQ below is well-defined even when the preceding
  // (non-fatal) EXPECT_TRUE fails and nothing gets captured.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar()));
  EXPECT_EQ(matched_shape, &scalar_shape);
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsArray()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&scalar_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithRank(0)));
  EXPECT_FALSE(Match(
      &scalar_shape,
      match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32)));
}
82 
// Shape and layout predicates on a rank-3 array built with MakeShape(); its
// layout is expected to report DENSE format, not SPARSE.
TEST(PatternMatcherTest, DenseArrayShape) {
  auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  // Initialized so the EXPECT_EQ below is well-defined even if the match
  // fails.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
  EXPECT_EQ(matched_shape, &array_shape);
  EXPECT_TRUE(Match(&array_shape, match::Shape().IsDenseArray()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsSparseArray()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3)));
  EXPECT_FALSE(
      Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
  Layout* matched_layout = nullptr;
  EXPECT_FALSE(Match(&array_shape,
                     match::Shape().WithLayout(
                         match::Layout(&matched_layout).WithSparseFormat())));
  // The failed match above must not have captured anything.
  EXPECT_EQ(matched_layout, nullptr);
  EXPECT_TRUE(Match(&array_shape,
                    match::Shape().WithLayout(
                        match::Layout(&matched_layout).WithDenseFormat())));
  EXPECT_EQ(matched_layout, &array_shape.layout());
}
105 
// Mirror of DenseArrayShape for a shape built with a sparse layout: it is an
// array but not a dense one, and its layout reports SPARSE format.
TEST(PatternMatcherTest, SparseArrayShape) {
  auto array_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {2, 3, 4}, 10);
  // Initialized so the EXPECT_EQ below is well-defined even if the match
  // fails.
  Shape* matched_shape = nullptr;
  EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
  EXPECT_EQ(matched_shape, &array_shape);
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsDenseArray()));
  EXPECT_TRUE(Match(&array_shape, match::Shape().IsSparseArray()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
  EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
  EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3)));
  EXPECT_FALSE(
      Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
  Layout* matched_layout = nullptr;
  EXPECT_FALSE(Match(&array_shape,
                     match::Shape().WithLayout(
                         match::Layout(&matched_layout).WithDenseFormat())));
  // The failed match above must not have captured anything.
  EXPECT_EQ(matched_layout, nullptr);
  EXPECT_TRUE(Match(&array_shape,
                    match::Shape().WithLayout(
                        match::Layout(&matched_layout).WithSparseFormat())));
  EXPECT_EQ(matched_layout, &array_shape.layout());
}
128 
// Tuple-shape predicates and WithSubshape() navigation into a two-element
// tuple (f32[1,2,3], s32[4,5]), including out-of-range subshape indices.
TEST(PatternMatcherTest, TupleShape) {
  auto tuple_shape = ShapeUtil::MakeTupleShape({
      ShapeUtil::MakeShape(F32, {1, 2, 3}),
      ShapeUtil::MakeShape(S32, {4, 5}),
  });
  EXPECT_TRUE(Match(&tuple_shape, match::Shape().IsTuple()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsArray()));
  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsScalar()));

  // Starts at nullptr so ASSERT_NE proves the capture actually happened.
  Shape* subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {0}, match::Shape(&subshape).WithElementType(F32).WithRank(3))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {0})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {0}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {0})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {0}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {1})))));

  // Reset before the second capture; otherwise the ASSERT_NE below would be
  // satisfied by the stale pointer left over from the {0} match.
  subshape = nullptr;
  ASSERT_TRUE(Match(
      &tuple_shape,
      match::Shape().WithSubshape(
          {1}, match::Shape(&subshape).WithElementType(S32).WithRank(2))));
  ASSERT_NE(subshape, nullptr);
  EXPECT_TRUE(
      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {1})));
  EXPECT_TRUE(Match(&tuple_shape,
                    match::Shape().WithSubshape(
                        {1}, match::Shape().EqualTo(
                                 &ShapeUtil::GetSubshape(tuple_shape, {1})))));
  EXPECT_FALSE(Match(&tuple_shape,
                     match::Shape().WithSubshape(
                         {1}, match::Shape().EqualTo(
                                  &ShapeUtil::GetSubshape(tuple_shape, {0})))));

  // Index {2} is past the end and {0, 0} descends into an array.
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({2}, match::Shape())));
  EXPECT_FALSE(
      Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape())));
}
176 
// WithFusionKind() accepts only a fusion of exactly the requested kind and
// never matches an instruction that is not a fusion.
TEST(PatternMatcherTest, FusionKind) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    fused_computation {
      ROOT fp0 = f32[] parameter(0)
    }

    ENTRY while.v11 {
      p0 = f32[] parameter(0)
      ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));

  HloInstruction* fusion =
      parsed_module->entry_computation()->root_instruction();
  const HloInstruction* param = fusion->operand(0);

  // The root is a kLoop fusion: right kind matches, wrong kind does not.
  EXPECT_TRUE(Match(
      fusion, match::Op().WithFusionKind(HloInstruction::FusionKind::kLoop)));
  EXPECT_FALSE(Match(
      fusion, match::Op().WithFusionKind(HloInstruction::FusionKind::kInput)));

  // The parameter feeding the fusion is not a fusion at all.
  EXPECT_FALSE(Match(
      param, match::Op().WithFusionKind(HloInstruction::FusionKind::kLoop)));
}
199 
// WithTupleIndex() and GetTupleElement(operand, index) match only when the
// get-tuple-element's index is exactly the requested one.
TEST(PatternMatcherTest, GetTupleElement) {
  constexpr char kModuleStr[] = R"(
    HloModule test_module

    ENTRY while.v11 {
      p0 = (f32[], f32[], f32[]) parameter(0)
      ROOT gte = f32[] get-tuple-element(p0), index=1
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  auto* gte = parsed_module->entry_computation()->root_instruction();

  // The root reads element 1 ...
  EXPECT_TRUE(Match(gte, match::Op().WithTupleIndex(1)));
  EXPECT_TRUE(Match(gte, match::GetTupleElement(match::Op(), 1)));

  // ... so its neighbours on either side must be rejected.
  EXPECT_FALSE(Match(gte, match::Op().WithTupleIndex(0)));
  EXPECT_FALSE(Match(gte, match::Op().WithTupleIndex(2)));
  EXPECT_FALSE(Match(gte, match::GetTupleElement(match::Op(), 0)));
}
217 
// AnyOf matches when at least one alternative matches, regardless of the
// position of that alternative, and fails when none does.
TEST(PatternMatcherTest, AnyOf) {
  using match::AnyOf;
  using match::ConstantScalar;

  constexpr char kModuleStr[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  auto* constant_one = parsed_module->entry_computation()->root_instruction();

  // The matching alternative in second, then first, position.
  EXPECT_TRUE(Match(constant_one, AnyOf<HloInstruction>(ConstantScalar(0),
                                                        ConstantScalar(1))));
  EXPECT_TRUE(Match(constant_one, AnyOf<HloInstruction>(ConstantScalar(1),
                                                        ConstantScalar(0))));
  // No alternative matches.
  EXPECT_FALSE(Match(constant_one, AnyOf<HloInstruction>(ConstantScalar(0),
                                                         ConstantScalar(2))));
}
234 
TEST(PatternMatcherTest,ConstantScalar)235 TEST(PatternMatcherTest, ConstantScalar) {
236   using match::ConstantEffectiveScalar;
237   using match::ConstantScalar;
238   using match::Op;
239   using match::Tuple;
240 
241   constexpr char kModuleStr[] = R"(
242     HloModule test_module
243     ENTRY test {
244       a = s32[] constant(1)
245       b = s32[1,1] constant({{2}})
246       c = s32[1,2] constant({{2,2}})
247       d = f32[] constant(1)
248       e = f32[] constant(1.25)
249       ROOT tuple = (s32[], s32[1,1], s32[1,2], f32[], f32[]) tuple(a,b,c,d,e)
250     })";
251   TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
252   auto* root = hlo_module->entry_computation()->root_instruction();
253 
254   const HloInstruction* a = root->operand(0);
255   const HloInstruction* b = root->operand(1);
256   const HloInstruction* c = root->operand(2);
257   const HloInstruction* d = root->operand(3);
258   const HloInstruction* e = root->operand(4);
259   EXPECT_TRUE(Match(a, ConstantScalar()));
260   EXPECT_TRUE(Match(a, ConstantScalar(1)));
261   EXPECT_TRUE(Match(a, ConstantEffectiveScalar()));
262   EXPECT_TRUE(Match(a, ConstantEffectiveScalar(1)));
263   EXPECT_FALSE(Match(a, ConstantScalar(2)));
264   EXPECT_FALSE(Match(a, ConstantScalar(2.01)));
265   EXPECT_FALSE(Match(a, ConstantEffectiveScalar(2)));
266   EXPECT_FALSE(Match(a, ConstantEffectiveScalar(1.01)));
267 
268   EXPECT_FALSE(Match(b, ConstantScalar()));
269   EXPECT_FALSE(Match(b, ConstantScalar(2)));
270   EXPECT_TRUE(Match(b, ConstantEffectiveScalar()));
271   EXPECT_TRUE(Match(b, ConstantEffectiveScalar(2)));
272 
273   EXPECT_FALSE(Match(c, ConstantScalar()));
274   EXPECT_FALSE(Match(c, ConstantScalar(2)));
275   EXPECT_FALSE(Match(c, ConstantEffectiveScalar()));
276   EXPECT_FALSE(Match(c, ConstantEffectiveScalar(2)));
277 
278   EXPECT_TRUE(Match(d, ConstantScalar(1)));
279   EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1)));
280   EXPECT_TRUE(Match(d, ConstantScalar(1.0)));
281   EXPECT_TRUE(Match(d, ConstantEffectiveScalar(1.0)));
282 
283   EXPECT_TRUE(Match(e, ConstantScalar(1.25f)));
284   EXPECT_TRUE(Match(e, ConstantScalar(1.25)));
285   EXPECT_TRUE(Match(e, ConstantEffectiveScalar(1.25)));
286   EXPECT_FALSE(Match(e, ConstantScalar(1)));
287   EXPECT_FALSE(Match(e, ConstantEffectiveScalar(1)));
288 
289   const HloInstruction* instr = nullptr;
290   EXPECT_TRUE(Match(a, ConstantScalar(&instr)));
291   EXPECT_EQ(instr, a);
292 
293   instr = nullptr;
294   EXPECT_TRUE(Match(a, ConstantScalar(&instr, 1)));
295   EXPECT_EQ(instr, a);
296 
297   instr = nullptr;
298   EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr)));
299   EXPECT_EQ(instr, a);
300 
301   instr = nullptr;
302   EXPECT_TRUE(Match(a, ConstantEffectiveScalar(&instr, 1)));
303   EXPECT_EQ(instr, a);
304 }
305 
// MultiplyAnyOrder matches a multiply whose operands satisfy the two
// sub-patterns in either order, captures the multiply itself, and exposes the
// same chaining API as Op() (e.g. IsNonConstant()).
TEST(PatternMatcherTest, MultiplyAnyOrder) {
  using match::ConstantScalar;
  using match::MultiplyAnyOrder;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
  auto* root = hlo_module->entry_computation()->root_instruction();

  // Initialized (and reset before each match) so that each EXPECT_EQ below
  // proves that particular match performed the capture.
  const HloInstruction* instr = nullptr;

  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
  EXPECT_EQ(instr, root);

  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
  EXPECT_EQ(instr, root);

  // Check that MultiplyAnyOrder exposes the same API as Op(), so we can call
  // e.g. IsNonConstant() on it, and that chaining keeps the capture.
  instr = nullptr;
  EXPECT_TRUE(Match(
      root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))
                .IsNonConstant()));
  EXPECT_EQ(instr, root);
  EXPECT_TRUE(
      Match(root, MultiplyAnyOrder(ConstantScalar(42), ConstantScalar(52))
                      .IsNonConstant()));
}
335 
// AnyOf stops at the first alternative that matches, so only that
// alternative's capture is written; later alternatives' captures stay null.
TEST(PatternMatcherTest, AnyOfShortCircuit) {
  using match::AnyOf;
  using match::Multiply;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      lhs = f16[] constant(42)
      rhs = f16[] constant(52)
      ROOT multiply = f16[] multiply(lhs, rhs)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  auto* multiply = parsed_module->entry_computation()->root_instruction();

  // Multiply alternative listed first: it wins, Op() is never consulted.
  {
    const HloInstruction* mul_capture = nullptr;
    const HloInstruction* op_capture = nullptr;

    ASSERT_TRUE(Match(multiply,
                      AnyOf<HloInstruction>(Multiply(&mul_capture, Op(), Op()),
                                            Op(&op_capture))));
    EXPECT_NE(nullptr, mul_capture);
    EXPECT_EQ(nullptr, op_capture);
  }

  // Op() listed first: it wins, the Multiply alternative is never consulted.
  {
    const HloInstruction* mul_capture = nullptr;
    const HloInstruction* op_capture = nullptr;

    ASSERT_TRUE(Match(multiply,
                      AnyOf<HloInstruction>(Op(&op_capture),
                                            Multiply(&mul_capture, Op(), Op()))));
    EXPECT_NE(nullptr, op_capture);
    EXPECT_EQ(nullptr, mul_capture);
  }
}
370 
// AllOf matches only when every sub-pattern matches, in any listed order; a
// single failing sub-pattern (here Broadcast) makes the whole AllOf fail.
TEST(PatternMatcherTest, AllOf) {
  using match::AllOf;
  using match::Broadcast;
  using match::Constant;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  auto* constant = parsed_module->entry_computation()->root_instruction();

  auto f16_scalar = ShapeUtil::MakeShape(F16, {});
  auto is_f16 = Constant().WithShapeEqualTo(&f16_scalar);
  auto is_f16_compatible = Constant().WithShapeCompatibleTo(&f16_scalar);
  auto is_scalar = Constant().WithShape(match::Shape().IsScalar());

  // Sanity check: each sub-pattern matches on its own.
  ASSERT_TRUE(Match(constant, is_scalar));
  ASSERT_TRUE(Match(constant, is_f16));
  ASSERT_TRUE(Match(constant, is_f16_compatible));

  // Their conjunction matches regardless of ordering.
  EXPECT_TRUE(Match(constant, AllOf<HloInstruction>(is_scalar, is_f16,
                                                    is_f16_compatible)));
  EXPECT_TRUE(Match(constant, AllOf<HloInstruction>(is_f16, is_f16_compatible,
                                                    is_scalar)));

  // Pairing any of them with a non-matching pattern fails.
  EXPECT_FALSE(Match(constant, AllOf<HloInstruction>(Broadcast(Op()), is_f16)));
  EXPECT_FALSE(Match(
      constant, AllOf<HloInstruction>(Broadcast(Op()), is_f16_compatible)));
  EXPECT_FALSE(
      Match(constant, AllOf<HloInstruction>(Broadcast(Op()), is_scalar)));
}
401 
// Even when an earlier AllOf sub-pattern matches on its own, nothing is
// captured if the AllOf as a whole fails.
TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
  using match::AllOf;
  using match::Broadcast;
  using match::Constant;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  auto* v = parsed_module->entry_computation()->root_instruction();

  const HloInstruction* captured = nullptr;

  // Constant(&captured) would match, but Broadcast(Op()) does not.
  ASSERT_FALSE(
      Match(v, AllOf<HloInstruction>(Constant(&captured), Broadcast(Op()))));
  EXPECT_EQ(nullptr, captured);

  // The same capturing sub-pattern alone does capture.
  ASSERT_TRUE(Match(v, Constant(&captured)));
  EXPECT_NE(nullptr, captured);
}
423 
// With MatchOption::capture turned off, a successful match leaves capture
// pointers untouched.
TEST(PatternMatcherTest, TestNoCapture) {
  using match::Constant;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      ROOT v = f16[] constant(42)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  auto* v = parsed_module->entry_computation()->root_instruction();

  const MatchOption no_capture{/*capture=*/false};
  const HloInstruction* captured = nullptr;
  ASSERT_TRUE(Match(v, Constant(&captured), no_capture));
  EXPECT_EQ(nullptr, captured);
}
439 
// When an AnyOf alternative matches, only captures belonging to that
// alternative are written. Here the plain two-operand Add wins, so the
// captures from the three-operand alternative stay null.
TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
  using match::Add;
  using match::AddAnyOrder;
  using match::AnyOf;
  using match::Op;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      u = f16[] parameter(0)
      v = f16[] parameter(1)
      ROOT add = f16[] add(u, v)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  auto* add = parsed_module->entry_computation()->root_instruction();

  const HloInstruction* first_addend = nullptr;
  const HloInstruction* second_addend = nullptr;
  const HloInstruction* third_addend = nullptr;

  auto sum_of_two = Add(Op(&first_addend), Op(&second_addend));
  auto sum_of_up_to_three = AnyOf<HloInstruction>(
      AddAnyOrder(sum_of_two, Op(&third_addend)), sum_of_two,
      Op(&first_addend));

  ASSERT_TRUE(Match(add, sum_of_up_to_three));
  EXPECT_NE(nullptr, first_addend);
  EXPECT_NE(nullptr, second_addend);
  EXPECT_EQ(nullptr, third_addend);
}
468 
// Concatenate(...) matches operands positionally: the operand patterns must be
// in the right order and there must be exactly as many as the instruction has.
TEST(PatternMatcherTest, TestConcat) {
  using match::Concatenate;
  using match::ConstantScalar;
  using match::Reshape;

  constexpr char kModuleStr[] = R"(
    HloModule test_module
    ENTRY test {
      c1 = u32[] constant(1)
      c2 = u32[] constant(2)
      c3 = u32[] constant(3)
      c4 = u32[] constant(4)
      r1 = u32[1] reshape(c1)
      r2 = u32[1] reshape(c2)
      r3 = u32[1] reshape(c3)
      r4 = u32[1] reshape(c4)
      ROOT concat = u32[4] concatenate(r1, r2, r3, r4), dimensions={0}
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module, ParseHloString(kModuleStr));
  auto* concat = parsed_module->entry_computation()->root_instruction();

  // Pattern for "a reshape of the scalar constant `value`".
  auto reshaped = [](int value) { return Reshape(ConstantScalar(value)); };

  // Exact operands in order.
  ASSERT_TRUE(Match(concat, Concatenate(reshaped(1), reshaped(2), reshaped(3),
                                        reshaped(4))));
  // First two operands swapped.
  ASSERT_FALSE(Match(concat, Concatenate(reshaped(2), reshaped(1), reshaped(3),
                                         reshaped(4))));
  // Too few operand patterns: a prefix, then a suffix.
  ASSERT_FALSE(
      Match(concat, Concatenate(reshaped(1), reshaped(2), reshaped(3))));
  ASSERT_FALSE(
      Match(concat, Concatenate(reshaped(2), reshaped(3), reshaped(4))));
}
505 
// Returns the text `pattern` produces from its DescribeTo() method, i.e. the
// human-readable description of what the pattern matches.
template <typename Pattern>
std::string Description(const Pattern& pattern) {
  // Output-only stream; <sstream> is included explicitly instead of relying on
  // a transitive include.
  std::ostringstream ss;
  pattern.DescribeTo(&ss);
  return ss.str();
}
512 
513 template <typename Elem, typename Pattern>
Explanation(Elem * elem,const Pattern & pattern)514 string Explanation(Elem* elem, const Pattern& pattern) {
515   std::stringstream ss;
516   MatchOption options{/*.capture=*/true, /*.explain_os=*/&ss};
517   Match(elem, pattern, options);
518   return ss.str();
519 }
520 template <typename Elem, typename Pattern>
Explanation(const std::unique_ptr<Elem> & elem,const Pattern & pattern)521 string Explanation(const std::unique_ptr<Elem>& elem, const Pattern& pattern) {
522   return Explanation(elem.get(), pattern);
523 }
524 template <typename Elem, typename Pattern>
Explanation(const Elem & elem,const Pattern & pattern)525 string Explanation(const Elem& elem, const Pattern& pattern) {
526   return Explanation(&elem, pattern);
527 }
528 
// Checks, in one call, that DescribeTo() on `pat` yields `want_desc` and that
// the explanation emitted while matching `subject` against `pat` (typically a
// failing match) equals `want_explanation`.
//
// This is a macro rather than a function so that failures are reported at the
// caller's line. It issues two separate EXPECT_EQs instead of something like
//
//   EXPECT_THAT(DescAndExplanation(...), ::testing::Pair(..., ...));
//
// because EXPECT_EQ prints a unified diff when multi-line strings differ,
// while EXPECT_THAT does not, and that diff makes failures much easier to
// read.
//
// Note: `pat` is expanded, and therefore evaluated, twice.
#define EXPECT_DESC_AND_EXPLANATION(subject, pat, want_desc,         \
                                    want_explanation)                \
  do {                                                               \
    EXPECT_EQ(Description(pat), (want_desc));                        \
    EXPECT_EQ(Explanation((subject), (pat)), (want_explanation));    \
  } while (0)
547 
// Descriptions and failure explanations produced by Layout patterns.
TEST(PatternMatcherTest, LayoutDescribeToAndExplain) {
  // `expected` is what EqualTo() compares against; `other` differs from it.
  auto expected = LayoutUtil::MakeLayout({1, 2});
  auto other = LayoutUtil::MakeLayout({2, 2});

  EXPECT_DESC_AND_EXPLANATION(static_cast<const Layout*>(nullptr), m::Layout(),
                              "a layout", "Layout is null");
  EXPECT_DESC_AND_EXPLANATION(other, m::Layout().EqualTo(&expected),
                              "a layout equal to {1,2}",
                              "Layout {2,2} is not equal to expected {1,2}");
  EXPECT_DESC_AND_EXPLANATION(other, m::Layout().WithSparseFormat(),
                              "a layout with format SPARSE",
                              "Layout has format DENSE but expected SPARSE");
  // `expected` satisfies EqualTo(), so only the format requirement fails.
  EXPECT_DESC_AND_EXPLANATION(expected,
                              m::Layout().EqualTo(&expected).WithSparseFormat(),
                              "a layout:\n"
                              " * equal to {1,2} AND\n"
                              " * with format SPARSE",
                              "Layout has format DENSE but expected SPARSE");
}
567 
TEST(PatternMatcherTest,ShapeDescribeToAndExplain)568 TEST(PatternMatcherTest, ShapeDescribeToAndExplain) {
569   auto shape = ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {0, 1});
570   auto layout = shape.layout();
571 
572   EXPECT_DESC_AND_EXPLANATION(static_cast<const Shape*>(nullptr), m::Shape(),
573                               "a shape", "Shape is null");
574   EXPECT_DESC_AND_EXPLANATION(
575       ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
576       m::Shape().EqualTo(&shape), "a shape equal to f32[1,2]{0,1}",
577       "Shape not equal to f32[1,2]{0,1}\n"
578       "in f32[1,2]{1,0}");
579   EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeShape(F32, {2, 2}),
580                               m::Shape().CompatibleTo(&shape),
581                               "a shape compatible with f32[1,2]",
582                               "Shape not compatible with f32[1,2]\n"
583                               "in f32[2,2]{1,0}");
584   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithElementType(F16),
585                               "a shape with element type F16",
586                               "Shape does not have element type F16\n"
587                               "in f32[1,2]{0,1}");
588   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsScalar(),
589                               "a shape that represents a scalar",
590                               "Shape is not a scalar\n"
591                               "in f32[1,2]{0,1}");
592   EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(), m::Shape().IsArray(),
593                               "a shape that represents an array",
594                               "Shape is not an array\n"
595                               "in ()");
596   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsTuple(),
597                               "a shape that represents a tuple",
598                               "Shape is not a tuple\n"
599                               "in f32[1,2]{0,1}");
600   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().IsEffectiveScalar(),
601                               "a shape that is an effective scalar",
602                               "Shape is not an effective scalar\n"
603                               "in f32[1,2]{0,1}");
604   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(42),
605                               "a shape that has 42 dimensions",
606                               "Shape does not have rank 42\n"
607                               "in f32[1,2]{0,1}");
608   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(0),
609                               "a shape that is a scalar",
610                               "Shape is not a scalar\n"
611                               "in f32[1,2]{0,1}");
612   EXPECT_DESC_AND_EXPLANATION(shape, m::Shape().WithRank(1).IsArray(),
613                               "a shape:\n"
614                               " * that has 1 dimension AND\n"
615                               " * that represents an array",
616                               "Shape does not have rank 1\n"
617                               "in f32[1,2]{0,1}");
618   EXPECT_DESC_AND_EXPLANATION(ShapeUtil::MakeNil(),
619                               m::Shape().IsArray().WithRank(1),
620                               "a shape:\n"
621                               " * that represents an array AND\n"
622                               " * that has 1 dimension",
623                               "Shape is not an array\n"
624                               "in ()");
625   EXPECT_DESC_AND_EXPLANATION(
626       ShapeUtil::MakeShapeWithLayout(F32, {1, 2}, {1, 0}),
627       m::Shape().WithLayoutEqualTo(&layout),
628       "a shape with\n  a layout equal to {0,1}",
629       "Layout {1,0} is not equal to expected {0,1}\n"
630       "in f32[1,2]{1,0}");
631   EXPECT_DESC_AND_EXPLANATION(
632       shape, m::Shape().WithLayout(m::Layout().WithSparseFormat()),
633       "a shape with\n  a layout with format SPARSE",
634       "Layout has format DENSE but expected SPARSE\n"
635       "in f32[1,2]{0,1}");
636   EXPECT_DESC_AND_EXPLANATION(shape,
637                               m::Shape().WithSubshapeEqualTo({10}, &shape),
638                               "a shape with subshape at index {10} which is\n"
639                               "  a shape equal to f32[1,2]{0,1}",
640                               "No subshape at {10}\n"
641                               "in f32[1,2]{0,1}");
642   EXPECT_DESC_AND_EXPLANATION(
643       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
644       m::Shape().WithSubshapeEqualTo({0}, &shape),
645       "a shape with subshape at index {0} which is\n"
646       "  a shape equal to f32[1,2]{0,1}",
647       "Shape not equal to f32[1,2]{0,1}\n"
648       "in f32[2,2]{1,0}\n"
649       "in subshape at {0}\n"
650       "in (f32[2,2])");
651   EXPECT_DESC_AND_EXPLANATION(shape,
652                               m::Shape().WithSubshapeCompatibleTo({10}, &shape),
653                               "a shape with subshape at index {10} which is\n"
654                               "  a shape compatible with f32[1,2]",
655                               "No subshape at {10}\n"
656                               "in f32[1,2]{0,1}");
657   EXPECT_DESC_AND_EXPLANATION(
658       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2, 2})}),
659       m::Shape().WithSubshapeCompatibleTo({0}, &shape),
660       "a shape with subshape at index {0} which is\n"
661       "  a shape compatible with f32[1,2]",
662       "Shape not compatible with f32[1,2]\n"
663       "in f32[2,2]{1,0}\n"
664       "in subshape at {0}\n"
665       "in (f32[2,2])");
666   EXPECT_DESC_AND_EXPLANATION(
667       ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({shape})}),
668       m::Shape().WithSubshape({0, 0}, m::Shape().IsScalar()),
669       "a shape with subshape at index {0,0} which is\n"
670       "  a shape that represents a scalar",
671       "Shape is not a scalar\n"
672       "in f32[1,2]{0,1}\n"
673       "in subshape at {0,0}\n"
674       "in ((f32[1,2]))");
675 }
676 
SetName(absl::string_view name,std::unique_ptr<HloInstruction> instr)677 std::unique_ptr<HloInstruction> SetName(absl::string_view name,
678                                         std::unique_ptr<HloInstruction> instr) {
679   instr->SetAndSanitizeName(string(name));
680   return instr;
681 }
682 
TEST(PatternMatcherTest,HloInstructionDescribeToAndExplain)683 TEST(PatternMatcherTest, HloInstructionDescribeToAndExplain) {
684   std::unique_ptr<HloInstruction> iota =
685       SetName("i", HloInstruction::CreateIota(ShapeUtil::MakeShape(S32, {42}),
686                                               /*iota_dimension=*/0));
687   std::unique_ptr<HloInstruction> constant =
688       SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
689 
690   EXPECT_DESC_AND_EXPLANATION(static_cast<const HloInstruction*>(nullptr),
691                               m::Op(), "an HloInstruction",
692                               "HloInstruction* is null");
693   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithName("foo"),
694                               "an HloInstruction named \"foo\"",
695                               "HloInstruction not named \"foo\"\n"
696                               "in i = s32[42]{0} iota(), iota_dimension=0");
697   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithOpcode(HloOpcode::kAdd),
698                               "an HloInstruction with opcode add",
699                               "HloInstruction doesn't have opcode add\n"
700                               "in i = s32[42]{0} iota(), iota_dimension=0");
701   EXPECT_DESC_AND_EXPLANATION(
702       constant, m::Op().IsNonConstant(),
703       "an HloInstruction with any opcode other than constant",
704       "HloInstruction has opcode constant, expected anything else\n"
705       "in c = s32[] constant(0)");
706   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithNumOperands(42),
707                               "an HloInstruction with 42 operands",
708                               "HloInstruction doesn't have 42 operands\n"
709                               "in i = s32[42]{0} iota(), iota_dimension=0");
710   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().WithShape(m::Shape().IsTuple()),
711                               "an HloInstruction outputting\n"
712                               "  a shape that represents a tuple",
713                               "Shape is not a tuple\n"
714                               "in s32[42]{0}\n"
715                               "in output shape\n"
716                               "in i = s32[42]{0} iota(), iota_dimension=0");
717   EXPECT_DESC_AND_EXPLANATION(
718       iota, m::Op().WithOperand(2, m::Op().WithOpcode(HloOpcode::kAdd)),
719       "an HloInstruction with operand 2 which is:\n"
720       "  an HloInstruction with opcode add",
721       "desired operand index 2 is out of bounds\n"
722       "in i = s32[42]{0} iota(), iota_dimension=0");
723 
724   EXPECT_DESC_AND_EXPLANATION(
725       SetName("a", HloInstruction::CreateBinary(ShapeUtil::MakeShape(S32, {}),
726                                                 HloOpcode::kAdd, constant.get(),
727                                                 constant.get())),
728       m::Op().WithOperand(1, m::Op().IsNonConstant()),
729       "an HloInstruction with operand 1 which is:\n"
730       "  an HloInstruction with any opcode other than constant",
731       "HloInstruction has opcode constant, expected anything else\n"
732       "in c = s32[] constant(0)\n"
733       "in operand 1\n"
734       "in a = s32[] add(s32[] c, s32[] c)");
735   EXPECT_DESC_AND_EXPLANATION(
736       iota, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop),
737       "an HloInstruction with fusion kind kLoop",
738       "HloInstruction does not have fusion kind kLoop; it's not a fusion\n"
739       "in i = s32[42]{0} iota(), iota_dimension=0");
740   EXPECT_DESC_AND_EXPLANATION(
741       iota, m::Op().WithTupleIndex(42),
742       "an HloInstruction which is a GTE with index 42",
743       "HloInstruction is not a GTE with index 42; it's not a GTE at all\n"
744       "in i = s32[42]{0} iota(), iota_dimension=0");
745   EXPECT_DESC_AND_EXPLANATION(iota, m::Op().IsConstantScalar(),
746                               "an HloInstruction which is a constant scalar",
747                               "HloInstruction is not a constant\n"
748                               "in i = s32[42]{0} iota(), iota_dimension=0");
749   EXPECT_DESC_AND_EXPLANATION(
750       SetName("c", HloInstruction::CreateConstant(
751                        LiteralUtil::CreateR1<int>({1, 2}))),
752       m::Op().IsConstantEffectiveScalar(),
753       "an HloInstruction which is a constant effective scalar",
754       "HloInstruction is not an effective scalar\n"
755       "in c = s32[2]{0} constant({1, 2})");
756   EXPECT_DESC_AND_EXPLANATION(
757       SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))),
758       m::Op().IsConstantScalar(42),
759       "an HloInstruction which is a constant scalar with value 42",
760       "HloInstruction's constant value 10 did not match expected value 42\n"
761       "in c = s32[] constant(10)");
762   EXPECT_DESC_AND_EXPLANATION(
763       SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.25))),
764       m::Op().IsConstantEffectiveScalar(1.25),
765       "an HloInstruction which is a constant effective scalar with value 1.25",
766       "HloInstruction's constant value 2.25 did not match expected value 1.25\n"
767       "in c = f64[] constant(2.25)");
768   EXPECT_DESC_AND_EXPLANATION(
769       constant, m::Op().Is(iota.get()),
770       absl::StrCat("an HloInstruction which is 0x", absl::Hex(iota.get()),
771                    " (i = s32[42]{0} iota(), iota_dimension=0)"),
772       absl::StrCat("HloInstruction 0x", absl::Hex(constant.get()), " is not 0x",
773                    absl::Hex(iota.get()),
774                    " (i = s32[42]{0} iota(), iota_dimension=0)\n"
775                    "in c = s32[] constant(0)"));
776 }
777 
TEST(PatternMatcherTest,HloInstructionMatcherAnyOrderDescribeTo)778 TEST(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) {
779   auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
780   EXPECT_DESC_AND_EXPLANATION(
781       SetName("a", HloInstruction::CreateBinary(
782                        scalar_s32, HloOpcode::kAdd,
783                        SetName("b", HloInstruction::CreateConstant(
784                                         LiteralUtil::CreateR0(0)))
785                            .get(),
786                        SetName("c", HloInstruction::CreateConstant(
787                                         LiteralUtil::CreateR0(0)))
788                            .get())),
789       m::AddAnyOrder(m::Op().WithName("b"), m::Op().WithName("bar")),
790       "an HloInstruction:\n"
791       " * with opcode add AND\n"
792       " * with two operands in either order:\n"
793       "    - an HloInstruction named \"b\"\n"
794       "    - an HloInstruction named \"bar\"",
795       "HloInstruction's operands (ignoring order) did not match second "
796       "matcher.  Specifically,\n"
797       " - an HloInstruction named \"bar\"\n"
798       "does not match LHS:\n"
799       " - HloInstruction not named \"bar\"\n"
800       "   in b = s32[] constant(0)\n"
801       "does not match RHS:\n"
802       " - HloInstruction not named \"bar\"\n"
803       "   in c = s32[] constant(0)\n"
804       "in a = s32[] add(s32[] b, s32[] c)");
805 
806   EXPECT_DESC_AND_EXPLANATION(
807       SetName("a",
808               HloInstruction::CreateBinary(
809                   scalar_s32, HloOpcode::kAdd,
810                   HloInstruction::CreateParameter(0, scalar_s32, "p").get(),
811                   SetName("c", HloInstruction::CreateConstant(
812                                    LiteralUtil::CreateR0(0)))
813                       .get())),
814       m::AddAnyOrder(m::Op().IsConstantScalar(), m::Op().IsConstant()),
815       "an HloInstruction:\n"
816       " * with opcode add AND\n"
817       " * with two operands in either order:\n"
818       "    - an HloInstruction which is a constant scalar\n"
819       "    - an HloInstruction with opcode constant",
820       "HloInstruction's LHS operand did not match either of the two matchers.  "
821       "Specifically,\n"
822       " - an HloInstruction which is a constant scalar\n"
823       "does not match LHS:\n"
824       " - HloInstruction is not a constant\n"
825       "   in p = s32[] parameter(0)\n"
826       "and\n"
827       " - an HloInstruction with opcode constant\n"
828       "does not match LHS:\n"
829       " - HloInstruction doesn't have opcode constant\n"
830       "   in p = s32[] parameter(0)\n"
831       "in a = s32[] add(s32[] p, s32[] c)");
832 }
833 
// Checks that AnyOf lists every alternative in its description and, when all
// alternatives fail, reports each alternative's own failure in turn.
TEST(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) {
  std::unique_ptr<HloInstruction> constant =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));
  const auto named_foo_or_bar = m::AnyOf<HloInstruction>(
      m::Op().WithName("foo"), m::Op().WithName("bar"));

  EXPECT_DESC_AND_EXPLANATION(constant, named_foo_or_bar,
                              "any of:\n"
                              " - an HloInstruction named \"foo\" OR\n"
                              " - an HloInstruction named \"bar\"",
                              "None of the following matchers succeeded:\n"
                              "Matcher #1\n"
                              " - an HloInstruction named \"foo\"\n"
                              "failed with\n"
                              " - HloInstruction not named \"foo\"\n"
                              "   in c = s32[] constant(0)\n"
                              "Matcher #2\n"
                              " - an HloInstruction named \"bar\"\n"
                              "failed with\n"
                              " - HloInstruction not named \"bar\"\n"
                              "   in c = s32[] constant(0)");
}
854 
// Checks m::Parameter() both unconstrained and with a parameter-number
// constraint, plus the explanations for each kind of mismatch.
TEST(PatternMatcherTest, Parameter) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  auto param1 = HloInstruction::CreateParameter(1, f32_scalar, "p1");
  auto param0 = HloInstruction::CreateParameter(0, f32_scalar, "p0");
  auto non_param =
      SetName("c", HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  // A parameter matches the unconstrained pattern and its own number only.
  EXPECT_TRUE(Match(param1.get(), m::Parameter()));
  EXPECT_TRUE(Match(param1.get(), m::Parameter(1)));
  EXPECT_FALSE(Match(param1.get(), m::Parameter(0)));

  // A non-parameter never matches, with or without a number.
  EXPECT_FALSE(Match(non_param.get(), m::Parameter()));
  EXPECT_FALSE(Match(non_param.get(), m::Parameter(1)));

  // Wrong opcode: the opcode check fails first.
  EXPECT_DESC_AND_EXPLANATION(non_param, m::Parameter(1),
                              "an HloInstruction:\n"
                              " * with opcode parameter AND\n"
                              " * which is parameter 1",
                              "HloInstruction doesn't have opcode parameter\n"
                              "in c = s32[] constant(0)");

  // Right opcode, wrong parameter number.
  EXPECT_EQ(Explanation(param0.get(), m::Parameter(1)),
            "HloInstruction is not parameter 1\n"
            "in p0 = f32[] parameter(0)");
}
878 
// Checks WithOneUse() and WithOneUser() as the user set of a single parameter
// changes: no users, one user, two users, and one user that consumes the
// parameter twice.
TEST(PatternMatcherTest, OneUseAndOneUser) {
  auto param =
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0");

  // A freshly created parameter has no users, so both predicates fail.
  EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse()));
  EXPECT_DESC_AND_EXPLANATION(
      param, m::Op().WithOneUse(),
      "an HloInstruction which has exactly one use",
      "HloInstruction has 0 users, but expected exactly one.\n"
      "in p0 = f32[] parameter(0)");

  EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser()));
  EXPECT_DESC_AND_EXPLANATION(
      param, m::Op().WithOneUser(),
      "an HloInstruction which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      "HloInstruction has 0 users, but expected exactly one.\n"
      "in p0 = f32[] parameter(0)");

  // The reshapes live only inside this scope. The expectations after it rely
  // on their destruction removing them from p0's user list.
  {
    // Creating a reshape of p0 registers it as p0's single user.
    auto reshape =
        SetName("r", HloInstruction::CreateReshape(
                         ShapeUtil::MakeShape(F32, {1}), param.get()));
    EXPECT_TRUE(Match(param.get(), m::Op().WithOneUse()));
    EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser()));

    // A second reshape makes two distinct users, so both predicates fail.
    auto reshape1 =
        SetName("r1", HloInstruction::CreateReshape(
                          ShapeUtil::MakeShape(F32, {1}), param.get()));
    EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse()));
    EXPECT_FALSE(Match(param.get(), m::Op().WithOneUser()));

    // Both predicates produce the same explanation, listing every user.
    const char* kMultipleUserExplanation =
        "HloInstruction has 2 users, but expected exactly one.\n"
        "All users:\n"
        " - r = f32[1]{0} reshape(f32[] p0)\n"
        " - r1 = f32[1]{0} reshape(f32[] p0)\n"
        "in p0 = f32[] parameter(0)";
    EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()),
              kMultipleUserExplanation);
    EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUser()),
              kMultipleUserExplanation);
  }

  // add(p0, p0): one user that uses p0 twice. WithOneUser accepts this, while
  // WithOneUse rejects it and reports the use count.
  auto add = SetName("add", HloInstruction::CreateBinary(
                                ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd,
                                param.get(), param.get()));
  EXPECT_TRUE(Match(param.get(), m::Op().WithOneUser()));
  EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse()));
  EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()),
            "HloInstruction is used 2 times by its user, but is expected to be "
            "used just once: add = f32[] add(f32[] p0, f32[] p0)\n"
            "in p0 = f32[] parameter(0)");
}
933 
// Checks the comparison matchers (m::Compare, m::Eq, m::Ne, m::Le and the
// *AnyOrder variants) against compare instructions with different directions
// and operand structures.
TEST(HloMatchersTest, Comparison) {
  auto shape = ShapeUtil::MakeShape(F32, {1});
  // A compare always yields PRED with the operands' dimensions. Building the
  // fixtures with an F32 result shape would produce HLO that does not
  // type-check.
  auto pred_shape = ShapeUtil::MakeShape(PRED, {1});
  auto p0 = HloInstruction::CreateParameter(0, shape, "param.0");
  auto p1 = HloInstruction::CreateParameter(1, shape, "param.1");
  auto eq = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kEq);
  auto ne = HloInstruction::CreateCompare(pred_shape, p0.get(), p1.get(),
                                          ComparisonDirection::kNe);
  auto add =
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0.get(), p1.get());
  auto le = HloInstruction::CreateCompare(pred_shape, p0.get(), add.get(),
                                          ComparisonDirection::kLe);

  // Positive matches: generic compare, direction-specific, and operand-aware.
  EXPECT_TRUE(Match(eq.get(), m::Compare()));
  EXPECT_TRUE(Match(eq.get(), m::Eq()));
  EXPECT_TRUE(Match(eq.get(), m::Eq(m::Parameter(0), m::Parameter(1))));
  EXPECT_TRUE(Match(eq.get(), m::EqAnyOrder(m::Parameter(1), m::Parameter(0))));
  EXPECT_TRUE(Match(ne.get(), m::Compare()));
  EXPECT_TRUE(Match(ne.get(), m::Ne()));
  EXPECT_TRUE(Match(
      le.get(),
      m::Compare(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_TRUE(Match(le.get(), m::Le(m::Parameter(0),
                                    m::Add(m::Parameter(0), m::Parameter(1)))));

  // Negative matches: wrong opcode, wrong direction, wrong operand order.
  EXPECT_FALSE(Match(eq.get(), m::Add()));
  EXPECT_FALSE(Match(eq.get(), m::Ne()));
  EXPECT_FALSE(
      Match(le.get(),
            m::Eq(m::Parameter(0), m::Add(m::Parameter(0), m::Parameter(1)))));
  EXPECT_FALSE(Match(eq.get(), m::Eq(m::Parameter(1), m::Parameter(0))));
  EXPECT_DESC_AND_EXPLANATION(
      eq, m::Ne().WithOneUser(),
      "an HloInstruction:\n"
      " * with opcode compare AND\n"
      " * which has comparison direction NE AND\n"
      " * which has exactly one user (but possibly is used "
      "multiple times by that instruction)",
      "HloInstruction is not comparison NE\n"
      "in compare = pred[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), "
      "direction=EQ");
}
976 
977 }  // namespace
978 }  // namespace xla
979