1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
18 
19 #include "absl/strings/str_replace.h"
20 #include "absl/strings/string_view.h"
21 #include "absl/utility/utility.h"
22 #include "tensorflow/compiler/xla/layout_util.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 
30 namespace xla {
31 
32 // A pattern matcher for HloInstructions, Shapes, and Layouts.
33 //
34 // The Match function's first argument must be HloInstruction*, Shape*, or
35 // Layout*. The second argument is a pattern that will be matched against the
36 // first argument, as described below.
37 //
38 // Patterns are constructed using the match::Op, match::Shape, or match::Layout
39 // functions. By default, the returned patterns will match any HloInstruction,
40 // Shape, or Layout, respectively. However the match can be made more specific
41 // by using the pattern's modifier methods, for example:
42 //
43 //   match::Op().WithOpcode(HloOpcode::kAdd).WithOperand(
44 //     0, match::Op().WithOpcode(HloOpcode::kConstant))
45 //
46 // This pattern will match Add instructions whose first operand is a constant.
47 //
48 // Each pattern type has the following modifiers, which are described where
49 // nontrivial.
50 //
51 //   Op():
52 //     - Is: is the given HloInstruction* (i.e. pointer equality)
53 //     - WithName
54 //     - WithOpcode
55 //     - WithoutOpcode: anything other than the given opcode
56 //     - WithShape: instr's shape matches the given pattern
57 //     - WithShapeEqualTo: instr's shape is equal to the given Shape
58 //     - WithShapeCompatibleTo: instr's shape is compatible with the given Shape
59 //     - WithNumOperands
60 //     - WithOperand: operand at the given index matches the given pattern
61 //     - IsConstant
62 //     - IsNonConstant
63 //     - IsConstantScalar/IsEffectiveConstantScalar: Optionally accepts a value,
64 //       e.g. IsConstantScalar() or IsConstantScalar(42).
65 //     - WithFusionKind
66 //     - WithTupleIndex: get-tuple-element operations with the given tuple index
67 //     - WithOneUse: Instruction is used as an operand exactly once.
68 //     - WithOneUser: Instruction is used by exactly one other instruction, but
69 //       is possibly used more than once as an operand (e.g. multiply(x,x)).
70 //     - WithComparisonDirection: instr has the given direction
71 //
72 //   Shape():
73 //     - EqualTo
74 //     - CompatibleTo
75 //     - IsScalar/IsEffectiveScalar/IsArray/IsTuple
76 //     - IsDenseArray/IsSparseArray
//     - WithLayout: shape's layout matches the given pattern (e.g.
//       Layout().WithDenseFormat())
79 //     - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
80 //       Layout, but not the result of Layout().foo())
81 //     - WithSubshape: shape is a tuple whose subshape matches the given pattern
82 //       (e.g. Shape().IsScalar()).
83 //     - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg
84 //       (i.e. another Shape, but not the result of Shape().foo())
85 //     - WithElementType: shape is an array/scalar with the given elem type
86 //     - WithRank: shape is an array/scalar with the given rank
87 //
//   Layout():
89 //     - EqualTo
90 //     - WithDenseFormat/WithSparseFormat
91 //
92 // Op(), Shape(), and Layout() may be passed an argument of type
93 // HloInstruction**, Shape**, or Layout**, respectively, or const versions of
94 // these pointers. If the pattern is matched, the address of the matched value
95 // will be "captured" and stored at this location.
96 //
97 // For example:
98 //   HloInstruction* foo = ...;
99 //   HloInstruction* matched_operand;
100 //   CHECK(Match(foo,
101 //               match::Op().WithOperand(0, match::Op(&matched_operand))));
102 //
103 // Helpers are provided for most HLO instructions. These helpers can be called
104 // with no arguments, in which case they will match any instruction matching the
105 // opcode. They may also be called with matches for the operands and with an
106 // optional capture. (The capture must be the first argument.) Some examples of
107 // these helpers and their equivalents are provided below.
108 
109 // Example nullary instruction:
110 //   Parameter()                    == Op().WithOpcode(HloOpcode::kParameter)
111 //   Parameter(&a)                  == Op(&a).WithOpcode(HloOpcode::kParameter)
112 //
113 // Example unary instruction:
114 //   Abs()                          == Op().WithOpcode(HloOpcode::kAbs)
115 //   Abs(Op(&a))                    == Op().WithOpcode(HloOpcode::kAbs)
//                                         .WithOperand(0, Op(&a))
117 //   Abs(&a, Op(&b))                == Op(&a).WithOpcode(HloOpcode::kAbs)
118 //                                           .WithOperand(0, Op(&b))
119 //
120 // Commutative binary instructions have a special form that accepts either order
121 // of args, e.g.:
122 //
123 //   AddAnyOrder(Parameter(1), Abs()) ==
124 //     Op().WithOpcode(HloOpcode::kAdd)
125 //         .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs());
126 //
127 //   MultiplyAnyOrder(&a, Parameter(), Abs())  // Captures the mul in `a`.
128 //
129 // The following additional helpers are provided.  In all cases, `&a` is
130 // optional.
131 //
132 //   ConstantScalar(&a)               == Op(&a).IsConstantScalar();
133 //   ConstantScalar(&a, v)            == Op(&a).IsConstantScalar(v);
134 //   ConstantEffectiveScalar(&a)      == Op(&a).IsConstantEffectiveScalar();
//   ConstantEffectiveScalar(&a, v)   == Op(&a).IsConstantEffectiveScalar(v);
136 //   NonConstant(&a)                  == Op(&a).IsNonConstant()
137 //   GetTupleElement(&a, b, index)    == Op(&a).WithTupleIndex(index)
138 //                                             .WithOperand(0, b);
139 //   Parameter(&a, n)                 == Op(&a).WithParameterNum(n);
140 
// Options controlling a single Match() call.
//
// Both members carry default initializers so that a default-constructed
// MatchOption is fully initialized (capturing on, no explanation). This
// mirrors the default that Match() itself uses. The struct remains an
// aggregate, so `MatchOption{capture, explain_os}` keeps working.
struct MatchOption {
  // If true, actually capture matched item into the user pointer.
  bool capture = true;

  // An explanation for why we failed to match is streamed here, if not-null.
  std::ostream* explain_os = nullptr;
};
148 
149 template <typename Value, typename Pattern>
150 bool Match(Value* value, const Pattern& pattern,
151            MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) {
152   if (option.capture) {
153     auto new_option = option;
154     new_option.capture = false;
155     if (!pattern.Match(value, new_option)) {
156       return false;
157     }
158   }
159   return pattern.Match(value, option);
160 }
161 
162 namespace match {
163 
164 namespace detail {
165 
// Macro for streaming to option.explain_os if it's not null.
//
//   EXPLAIN << "value of foo(): " << foo()
//
// The streamed operands are evaluated only when option.explain_os is
// non-null.
//
// The macro expands to a complete `if (...) {} else` statement so that it
// composes safely with a caller's own if/else. In
//
//   if (cond) EXPLAIN << "x"; else Other();
//
// the caller's `else` binds to `if (cond)`, as intended. A plain
// `if (option.explain_os) *option.explain_os` expansion would silently
// capture that `else` instead (the dangling-else problem).
//
// Any previous definition of EXPLAIN is saved with push_macro, presumably so
// that a matching pop_macro later in the file can restore it.
#pragma push_macro("EXPLAIN")
#define EXPLAIN \
  if (!option.explain_os) {} else *option.explain_os
173 
174 // kIndentInc is the additional number of spaces that we indent by when we
175 // increase the indent "by one".
176 enum {
177   kIndentInc = 2,
178 };
179 
180 // Writes a newline and then `indent` spaces.
181 //
182 // We follow an unintuitive convention in this file's pretty-printers: Indents
183 // are performed by the caller, not the callee.  For example, if you want to
184 // print
185 //
186 //   foo:
187 //    - bar
188 //
189 // you'd do:
190 //
191 //  Foo::DescribeTo(std::ostream* os, int64 indent) {
192 //    *os << "foo:";
//    Indent(os, indent);  // Create a newline at the *current* indent level.
//    *os << " - ";
//    bar.DescribeTo(os, indent + 3);  // + 3 because strlen(" - ") == 3.
196 //  }
197 //
198 //  Bar::DescribeTo(std::ostream* os, int64 indent) { *os << "bar"; }
199 //
200 // Notice that Bar::DescribeTo() does not call Indent; the indenting is
201 // performed by Foo.  This convention allows the caller to decide whether a
202 // matcher is preceded by a newline, which is important e.g. for the AllOf
203 // matcher.
204 //
205 // (Incidentally, indenting in Match's explanations is handled differently.
206 // Indents are a common case in DescribeTo [we're printing a whole tree], but
207 // they're a special case in Match [we're printing only a path through the tree
208 // that encounters a failing node]. Indents in Match only appear when we
209 // encounter a failing disjunction, so we just handle them as a special case
210 // there.)
Indent(std::ostream * os,int64 indent)211 inline void Indent(std::ostream* os, int64 indent) {
212   *os << "\n";
213   for (int64 i = 0; i < indent; ++i) {
214     *os << " ";
215   }
216 }
217 
// Type trait: IsTrivialMatcher<T>::value is true exactly when T declares a
// static member kIsTrivialMatcher that evaluates to true. If the member is
// absent or false, the enable_if below fails to substitute and the primary
// (false) template is used instead.
//
// Trivial matchers get special treatment when pretty-printing. For example,
// when describing a conjunction of matchers we don't print "and" after a
// trivial matcher. This yields e.g.
//    "a shape compatible with f32[1,2]"
// rather than
//    "a shape AND compatible with f32[1,2]"
template <typename T, typename Dummy = void>
struct IsTrivialMatcher : std::false_type {};

template <typename T>
struct IsTrivialMatcher<T, typename std::enable_if<T::kIsTrivialMatcher>::type>
    : std::true_type {};
236 
// Conjunction of `Patterns` over values of type `Item`. It matches only if
// every sub-pattern matches. Sub-patterns are tried left to right, and the
// first failure stops evaluation. Build instances with AllOf() rather than
// constructing them directly.
template <typename Item, typename... Patterns>
class AllOfPattern {
 public:
  explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}

  // Matches a const item against every sub-pattern in order.
  bool Match(const Item* item, MatchOption option) const {
    bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
    // This invariant is guaranteed by the top-level Match and AnyOf.
    DCHECK(matched || !option.capture);
    return matched;
  }

  // Matches a mutable item against every sub-pattern in order, so that
  // sub-patterns can capture a non-const pointer.
  bool Match(Item* item, MatchOption option) const {
    bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
    // This invariant is guaranteed by the top-level Match and AnyOf.
    DCHECK(matched || !option.capture);
    return matched;
  }

  // Pretty-prints the conjunction; see DescribeToImpl for the layout rules.
  void DescribeTo(std::ostream* os, int64 indent = 0) const {
    DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
  }

  // Accessor for patterns_.  Please don't use this outside of this file.
  const std::tuple<Patterns...>& patterns() const { return patterns_; }

 private:
  // Matches sub-pattern `index`, then recurses on `index + 1`. The `&&`
  // short-circuits, so later sub-patterns are not evaluated after a failure.
  template <typename ItemType, size_t index>
  bool MatchImpl(ItemType* item, MatchOption option,
                 std::integral_constant<size_t, index>) const {
    // We don't need to do any EXPLAINing here; it's all correctly handled by
    // our sub-matchers (if any fail).
    return std::get<index>(patterns_).Match(item, option) &&
           MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
  }

  // Recursion terminator: every sub-pattern has matched.
  template <typename ItemType>
  bool MatchImpl(ItemType* item, MatchOption option,
                 std::integral_constant<size_t, sizeof...(Patterns)>) const {
    return true;
  }

  // Pretty-printing a conjunction has some special cases to make it easy to
  // read in the simple (common) case.
  //
  // If sizeof...(Patterns) == 1, prints as e.g.
  //
  //   a shape
  //
  // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. "a
  // shape") prints as
  //
  //   a shape compatible with f32[1,2]
  //
  // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as
  //
  //   a shape:
  //    * compatible with f32[1,2] AND
  //    * that represents a scalar
  //
  // Otherwise prints as:
  //
  //   all of:
  //    * foo AND
  //    * bar
  //
  // Each call prints sub-pattern `index` and then recurses on `index + 1`.
  template <size_t index>
  void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
                      int64 indent) const {
    // Whether the first sub-pattern declares itself trivial (see
    // IsTrivialMatcher); this selects between the layouts described above.
    constexpr bool first_is_trivial =
        IsTrivialMatcher<typename std::remove_reference<decltype(
            std::get<0>(patterns_))>::type>::value;
    constexpr bool is_last = index == sizeof...(Patterns) - 1;
    const auto& submatcher = std::get<index>(patterns_);

    // Prints " * <submatcher>", followed by " AND" and a newline unless this
    // is the last sub-pattern. The nested description is indented by 3
    // because strlen(" * ") == 3.
    auto print_bulleted_item = [&] {
      *os << " * ";
      submatcher.DescribeTo(os, indent + 3);
      if (!is_last) {
        *os << " AND";
        Indent(os, indent);
      }
    };

    if (index == 0) {
      if (first_is_trivial || is_last) {
        submatcher.DescribeTo(os, indent + kIndentInc);
        if (sizeof...(Patterns) > 2) {
          *os << ":";
          Indent(os, indent);
        }
      } else {
        *os << "all of:";
        Indent(os, indent);
        print_bulleted_item();
      }
    } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) {
      // "a shape" + " compatible with ..." on a single line.
      *os << " ";
      submatcher.DescribeTo(os, indent);
    } else {
      print_bulleted_item();
    }
    DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
  }

  // Recursion terminator: every sub-pattern has been printed.
  void DescribeToImpl(std::ostream* os,
                      std::integral_constant<size_t, sizeof...(Patterns)>,
                      int64 indent) const {}

  std::tuple<Patterns...> patterns_;
};
348 
349 }  // namespace detail
350 
351 // Returns a pattern that represents the conjunction of all input patterns. All
352 // patterns need to match in order to have the AllOf pattern match.
353 template <typename Item, typename... Patterns>
354 detail::AllOfPattern<typename std::remove_const<Item>::type, Patterns...> AllOf(
355     const Patterns&... patterns) {
356   return detail::AllOfPattern<typename std::remove_const<Item>::type,
357                               Patterns...>(patterns...);
358 }
359 
360 // AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...>.
361 //
362 // This transformation is necessary for good pretty-printing.
363 template <typename Item, typename... InnerPs, typename... OuterPs>
364 detail::AllOfPattern<typename std::remove_const<Item>::type, InnerPs...,
365                      OuterPs...>
366 AllOf(const detail::AllOfPattern<Item, InnerPs...>& inner_p,
367       const OuterPs&... outer_ps) {
368   // Invoke constructor of AllOfPattern<Item, InnerPs..., OuterPs...>.
369   auto make_all_of = [](const InnerPs&... inner_ps,
370                         const OuterPs&... outer_ps) {
371     return detail::AllOfPattern<typename std::remove_const<Item>::type,
372                                 InnerPs..., OuterPs...>(inner_ps...,
373                                                         outer_ps...);
374   };
375   return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(),
376                                                  std::make_tuple(outer_ps...)));
377 }
378 
379 namespace detail {
380 
// Forward declaration of LayoutPattern, which is defined later in this file.
template <typename LayoutType, typename Impl>
class LayoutPattern;
383 
384 // The base LayoutPattern implementation. Matches only if the layout is not
385 // nullptr.
386 class LayoutPatternBaseImpl {
387  public:
388   bool Match(const ::xla::Layout* layout, MatchOption option) const {
389     if (layout == nullptr) {
390       EXPLAIN << "Layout is null";
391       return false;
392     }
393     return true;
394   }
395 
396   void DescribeTo(std::ostream* os, int64 indent = 0) const {
397     *os << "a layout";
398   }
399 
400   static constexpr bool kIsTrivialMatcher = true;
401 };
402 
403 // A LayoutPattern implementation that matches only if the layout equals a
404 // Layout proto.
405 class LayoutPatternEqualImpl {
406  public:
407   explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout)
408       : layout_(layout) {}
409 
410   bool Match(const ::xla::Layout* layout, MatchOption option) const {
411     if (!LayoutUtil::Equal(*layout_, *layout)) {
412       EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout)
413               << " is not equal to expected "
414               << LayoutUtil::HumanString(*layout_);
415       return false;
416     }
417     return true;
418   }
419 
420   void DescribeTo(std::ostream* os, int64 indent = 0) const {
421     *os << "equal to " << LayoutUtil::HumanString(*layout_);
422   }
423 
424  private:
425   const ::xla::Layout* layout_;
426 };
427 
428 // A LayoutPattern implementation that matches only if the layout has a given
429 // format.
430 class LayoutPatternFormatImpl {
431  public:
432   explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {}
433 
434   bool Match(const ::xla::Layout* layout, MatchOption option) const {
435     if (layout->format() != format_) {
436       EXPLAIN << "Layout has format " << Format_Name(layout->format())
437               << " but expected " << Format_Name(format_);
438       return false;
439     }
440     return true;
441   }
442 
443   void DescribeTo(std::ostream* os, int64 indent = 0) const {
444     *os << "with format " << Format_Name(format_);
445   }
446 
447  private:
448   Format format_;
449 };
450 
// A pattern that matches Layouts.
//
// `LayoutType` is `const ::xla::Layout` or `::xla::Layout` (see the Layout()
// factories) and fixes the type of the optional capture pointer. `Impl` is
// the matcher, or AllOf conjunction of matchers, that decides whether a
// layout matches. The modifier methods below are const and return a new,
// more specific pattern; they do not mutate this one.
template <typename LayoutType, typename Impl>
class LayoutPattern {
 private:
  // Returns a copy of this pattern whose Impl is AllOf(impl_, new_impl), i.e.
  // one that additionally requires `new_impl` to match. The capture pointer is
  // carried over unchanged.
  template <typename NewImpl>
  auto AppendImpl(NewImpl new_impl) const
      -> LayoutPattern<LayoutType,
                       decltype(AllOf<Layout>(std::declval<Impl>(),
                                              std::move(new_impl)))> {
    auto new_allof = AllOf<Layout>(impl_, std::move(new_impl));
    return LayoutPattern<LayoutType, decltype(new_allof)>(std::move(new_allof),
                                                          matched_layout_);
  }

 public:
  // `matched_layout` may be null, in which case nothing is ever captured.
  explicit constexpr LayoutPattern(const Impl& impl,
                                   LayoutType** matched_layout)
      : impl_(impl), matched_layout_(matched_layout) {}

  // Returns true iff the layout matches the pattern. On a match, the layout is
  // also stored in *matched_layout_ when option.capture is set and a capture
  // pointer was provided.
  // NOTE(review): if LayoutType is non-const, instantiating this overload
  // would assign a const Layout* to a Layout*, which does not compile. Such
  // patterns presumably only match mutable layouts.
  bool Match(const ::xla::Layout* layout, MatchOption option) const {
    if (impl_.Match(layout, option)) {
      if (option.capture && matched_layout_) {
        *matched_layout_ = layout;
      }
      return true;
    }
    return false;
  }

  // Same as above, for a mutable layout.
  bool Match(::xla::Layout* layout, MatchOption option) const {
    if (impl_.Match(layout, option)) {
      if (option.capture && matched_layout_) {
        *matched_layout_ = layout;
      }
      return true;
    }
    return false;
  }

  void DescribeTo(std::ostream* os, int64 indent = 0) const {
    impl_.DescribeTo(os, indent);
  }

  // Returns a pattern that additionally requires the layout to equal the given
  // proto. The layout must outlive the returned pattern.
  constexpr auto EqualTo(const ::xla::Layout* layout) const
      -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) {
    return AppendImpl(LayoutPatternEqualImpl(layout));
  }

  // Returns a pattern that additionally requires a dense layout format.
  constexpr auto WithDenseFormat() const
      -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) {
    return AppendImpl(LayoutPatternFormatImpl(DENSE));
  }

  // Returns a pattern that additionally requires a sparse layout format.
  constexpr auto WithSparseFormat() const
      -> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) {
    return AppendImpl(LayoutPatternFormatImpl(SPARSE));
  }

 private:
  Impl impl_;
  // Optional capture destination; may be null.
  LayoutType** matched_layout_;
};
519 
// Disjunction of `Patterns` over values of type `Item`. It matches if any
// sub-pattern matches. Alternatives are tried left to right, and the first
// success stops evaluation. Build instances with AnyOf() rather than
// constructing them directly.
template <typename Item, typename... Patterns>
class AnyOfPattern {
 public:
  explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}

  // Matches a const item against the alternatives.
  bool Match(const Item* item, MatchOption option) const {
    return MatchImpl(item, option);
  }

  // Matches a mutable item against the alternatives.
  bool Match(Item* item, MatchOption option) const {
    return MatchImpl(item, option);
  }

  // Prints "any of:" followed by one " - " bullet per alternative.
  void DescribeTo(std::ostream* os, int64 indent = 0) const {
    *os << "any of:";
    Indent(os, indent);
    DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
  }

 private:
  // Runs the alternatives. Explanations from failing alternatives go to a
  // local buffer, which is forwarded to the caller's explain_os only if every
  // alternative fails.
  template <typename ItemType>
  bool MatchImpl(ItemType* item, MatchOption option) const {
    // If we're generating an explanation, buffer it until we know we failed.
    absl::optional<std::stringstream> explanation;
    MatchOption new_option = option;
    if (option.explain_os) {
      new_option.explain_os = &explanation.emplace();
    }
    bool rv = MatchRecursiveImpl(item, new_option,
                                 std::integral_constant<size_t, 0>());
    if (!rv && option.explain_os) {
      EXPLAIN << "None of the following matchers succeeded:";
      EXPLAIN << explanation->str();
    }
    return rv;
  }

  // Tries alternative `index` as a dry run, with capture disabled and its
  // explanation buffered.
  // - On success, if the caller asked for capture, it re-runs that same
  //   alternative with the caller's options. Captures are therefore written
  //   only for the alternative that actually matched.
  // - On failure, it appends a description of the alternative plus its
  //   buffered explanation, then moves on to `index + 1`.
  template <typename ItemType, size_t index>
  bool MatchRecursiveImpl(ItemType* item, MatchOption option,
                          std::integral_constant<size_t, index>) const {
    auto new_option = option;
    new_option.capture = false;

    absl::optional<std::stringstream> explanation;
    if (option.explain_os) {
      new_option.explain_os = &explanation.emplace();
    }

    // Try to match the sub-pattern without capturing behavior.
    if (std::get<index>(patterns_).Match(item, new_option)) {
      // Capture the branch.
      if (option.capture) {
        // TODO(timshen): Currently the behavior can be exponential. Optimize it
        // with memoization or recording the matched sub-pattern index, if it
        // takes too long to run.
        //
        // Specifically, the "memoization" approach is to create an empty
        // container with the key (pattern, instruction), and value as whether
        // matched or not.
        //
        // Alternatively, we may run the pattern matching with captures off, but
        // instead record a "trace" somewhere, indicating how exactly the
        // pattern matches the input. For example, the trace information for
        // AnyOf will be a runtime number indicate which sub-pattern is matched.
        // Then we run another pass to do captures only with the help of the
        // trace.
        bool matched = std::get<index>(patterns_).Match(item, option);
        DCHECK(matched);
      }
      return true;
    }
    if (option.explain_os) {
      EXPLAIN << "\nMatcher #" << index + 1;
      EXPLAIN << "\n - ";
      std::get<index>(patterns_).DescribeTo(option.explain_os, /*indent=*/3);
      EXPLAIN << "\nfailed with";
      EXPLAIN << "\n - ";
      // Indent the nested explanation by three spaces so that it lines up
      // under the " - " bullet.
      EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n   "}});
    }
    return MatchRecursiveImpl(item, option,
                              std::integral_constant<size_t, index + 1>());
  }

  // Recursion terminator: no alternative matched.
  template <typename ItemType>
  bool MatchRecursiveImpl(
      ItemType* item, MatchOption option,
      std::integral_constant<size_t, sizeof...(Patterns)>) const {
    return false;
  }

  // Prints alternative `index` as a " - " bullet. Its description is indented
  // by 3 (strlen(" - ")). All but the last bullet are followed by " OR" and a
  // newline.
  template <size_t index>
  void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
                      int64 indent) const {
    *os << " - ";
    std::get<index>(patterns_).DescribeTo(os, indent + 3);
    if (index != sizeof...(Patterns) - 1) {
      *os << " OR";
      Indent(os, indent);
    }
    DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
  }

  // Recursion terminator: all alternatives printed.
  void DescribeToImpl(std::ostream* os,
                      std::integral_constant<size_t, sizeof...(Patterns)>,
                      int64 indent) const {}

  std::tuple<Patterns...> patterns_;
};
628 
629 }  // namespace detail
630 
631 // Returns a pattern that represents the logical disjunction of the input
632 // patterns. The returned pattern matches from left to right, and stops on the
633 // first match.
634 template <typename Item, typename... Patterns>
635 detail::AnyOfPattern<typename std::remove_const<Item>::type, Patterns...> AnyOf(
636     const Patterns&... patterns) {
637   return detail::AnyOfPattern<typename std::remove_const<Item>::type,
638                               Patterns...>(patterns...);
639 }
640 
641 // Creates a layout pattern that will capture the matched layout in the
642 // argument.
643 inline constexpr detail::LayoutPattern<const ::xla::Layout,
644                                        detail::LayoutPatternBaseImpl>
645 Layout(const ::xla::Layout** matched_layout = nullptr) {
646   return detail::LayoutPattern<const ::xla::Layout,
647                                detail::LayoutPatternBaseImpl>(
648       detail::LayoutPatternBaseImpl(), matched_layout);
649 }
650 
651 // Creates a layout pattern that will capture the matched layout in the
652 // argument.
653 inline constexpr detail::LayoutPattern<::xla::Layout,
654                                        detail::LayoutPatternBaseImpl>
655 Layout(::xla::Layout** matched_layout) {
656   return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>(
657       detail::LayoutPatternBaseImpl(), matched_layout);
658 }
659 
660 namespace detail {
661 
// Forward declaration of ShapePattern; presumably defined later in this file.
template <typename ShapeType, typename Impl>
class ShapePattern;
664 
665 // The base ShapePattern implementation. Matches only if the shape is not
666 // nullptr.
667 class ShapePatternBaseImpl {
668  public:
669   bool Match(const ::xla::Shape* shape, MatchOption option) const {
670     if (shape == nullptr) {
671       EXPLAIN << "Shape is null";
672     }
673     return shape != nullptr;
674   }
675 
676   void DescribeTo(std::ostream* os, int64 indent = 0) const {
677     *os << "a shape";
678   }
679 
680   static constexpr bool kIsTrivialMatcher = true;
681 };
682 
683 // A ShapePattern implementation that matches only if the shape equals a Shape
684 // proto.
685 class ShapePatternEqualImpl {
686  public:
687   explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape)
688       : shape_(shape) {}
689 
690   bool Match(const ::xla::Shape* shape, MatchOption option) const {
691     if (!ShapeUtil::Equal(*shape_, *shape)) {
692       EXPLAIN << "Shape not equal to "
693               << ShapeUtil::HumanStringWithLayout(*shape_);
694       return false;
695     }
696     return true;
697   }
698 
699   void DescribeTo(std::ostream* os, int64 indent = 0) const {
700     *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_);
701   }
702 
703  private:
704   const ::xla::Shape* shape_;
705 };
706 
707 // A ShapePattern implementation that matches only if the shape is compatible to
708 // a Shape proto.
709 class ShapePatternCompatibleImpl {
710  public:
711   explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape)
712       : shape_(shape) {}
713 
714   bool Match(const ::xla::Shape* shape, MatchOption option) const {
715     if (!ShapeUtil::Compatible(*shape_, *shape)) {
716       EXPLAIN << "Shape not compatible with "
717               << ShapeUtil::HumanString(*shape_);
718       return false;
719     }
720     return true;
721   }
722 
723   void DescribeTo(std::ostream* os, int64 indent = 0) const {
724     *os << "compatible with " << ShapeUtil::HumanString(*shape_);
725   }
726 
727  private:
728   const ::xla::Shape* shape_;
729 };
730 
731 // A ShapePattern implementation that matches only if the shape has a given
732 // element type.
733 class ShapePatternElementTypeImpl {
734  public:
735   explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type)
736       : element_type_(element_type) {}
737 
738   bool Match(const ::xla::Shape* shape, MatchOption option) const {
739     if (shape->element_type() != element_type_) {
740       EXPLAIN << "Shape does not have element type "
741               << PrimitiveType_Name(element_type_);
742       return false;
743     }
744     return true;
745   }
746 
747   void DescribeTo(std::ostream* os, int64 indent = 0) const {
748     *os << "with element type " << PrimitiveType_Name(element_type_);
749   }
750 
751  private:
752   PrimitiveType element_type_;
753 };
754 
755 // A ShapePattern implementation that matches only if the shape is scalar.
756 class ShapePatternIsScalarImpl {
757  public:
758   explicit constexpr ShapePatternIsScalarImpl() {}
759 
760   bool Match(const ::xla::Shape* shape, MatchOption option) const {
761     if (!ShapeUtil::IsScalar(*shape)) {
762       EXPLAIN << "Shape is not a scalar";
763       return false;
764     }
765     return true;
766   }
767 
768   void DescribeTo(std::ostream* os, int64 indent = 0) const {
769     *os << "that represents a scalar";
770   }
771 };
772 
773 // A ShapePattern implementation that matches only if the shape is an array
774 class ShapePatternIsArrayImpl {
775  public:
776   explicit constexpr ShapePatternIsArrayImpl() {}
777 
778   bool Match(const ::xla::Shape* shape, MatchOption option) const {
779     if (!shape->IsArray()) {
780       EXPLAIN << "Shape is not an array";
781       return false;
782     }
783     return true;
784   }
785 
786   void DescribeTo(std::ostream* os, int64 indent = 0) const {
787     *os << "that represents an array";
788   }
789 };
790 
791 // A ShapePattern implementation that matches only if the shape is a tuple.
792 class ShapePatternIsTupleImpl {
793  public:
794   explicit constexpr ShapePatternIsTupleImpl() {}
795 
796   bool Match(const ::xla::Shape* shape, MatchOption option) const {
797     if (!shape->IsTuple()) {
798       EXPLAIN << "Shape is not a tuple";
799       return false;
800     }
801     return true;
802   }
803 
804   void DescribeTo(std::ostream* os, int64 indent = 0) const {
805     *os << "that represents a tuple";
806   }
807 };
808 
809 // A ShapePattern implementation that matches only if the shape is an effective
810 // scalar.
811 class ShapePatternEffectiveScalarImpl {
812  public:
813   explicit constexpr ShapePatternEffectiveScalarImpl() {}
814 
815   bool Match(const ::xla::Shape* shape, MatchOption option) const {
816     if (!ShapeUtil::IsEffectiveScalar(*shape)) {
817       EXPLAIN << "Shape is not an effective scalar";
818       return false;
819     }
820     return true;
821   }
822 
823   void DescribeTo(std::ostream* os, int64 indent = 0) const {
824     *os << "that is an effective scalar";
825   }
826 };
827 
828 // A ShapePattern implementation that matches only if the shape has a given
829 // rank.
830 class ShapePatternRankImpl {
831  public:
832   explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {}
833 
834   bool Match(const ::xla::Shape* shape, MatchOption option) const {
835     if (shape->rank() != rank_) {
836       if (rank_ == 0) {
837         EXPLAIN << "Shape is not a scalar";
838       } else {
839         EXPLAIN << "Shape does not have rank " << rank_;
840       }
841       return false;
842     }
843     return true;
844   }
845 
846   void DescribeTo(std::ostream* os, int64 indent = 0) const {
847     if (rank_ == 0) {
848       *os << "that is a scalar";
849     } else {
850       *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? "s" : "");
851     }
852   }
853 
854  private:
855   int64 rank_;
856 };
857 
858 // A ShapePattern implementation that matches only if the shape has a layout
859 // that matches a given pattern.
860 template <typename LayoutType, typename LayoutImpl>
861 class ShapePatternLayoutImpl {
862  public:
863   explicit constexpr ShapePatternLayoutImpl(
864       const LayoutPattern<LayoutType, LayoutImpl>& layout)
865       : layout_(layout) {}
866 
867   bool Match(const ::xla::Shape* shape, MatchOption option) const {
868     return LayoutUtil::HasLayout(*shape) &&
869            layout_.Match(&shape->layout(), option);
870   }
871 
872   bool Match(Shape* shape, MatchOption option) const {
873     if (!LayoutUtil::HasLayout(*shape)) {
874       EXPLAIN << "Shape does not have a layout";
875       return false;
876     }
877     if (!layout_.Match(shape->mutable_layout(), option)) {
878       EXPLAIN << "\nin layout";
879       return false;
880     }
881     return true;
882   }
883 
884   void DescribeTo(std::ostream* os, int64 indent = 0) const {
885     *os << "with";
886     Indent(os, indent + kIndentInc);
887     layout_.DescribeTo(os, indent + kIndentInc);
888   }
889 
890  private:
891   LayoutPattern<LayoutType, LayoutImpl> layout_;
892 };
893 
894 // A ShapePattern implementation that matches only if the shape has a subshape
895 // that matches a given pattern.
896 template <typename SubshapeType, typename SubshapeImpl>
897 class ShapePatternSubshapeImpl {
898  public:
899   explicit ShapePatternSubshapeImpl(
900       ShapeIndexView index,
901       const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
902       : index_(index), subshape_(subshape) {}
903 
904   bool Match(const ::xla::Shape* shape, MatchOption option) const {
905     return MatchImpl(shape, option);
906   }
907 
908   bool Match(::xla::Shape* shape, MatchOption option) const {
909     return MatchImpl(shape, option);
910   }
911 
912   void DescribeTo(std::ostream* os, int64 indent = 0) const {
913     *os << "with subshape at index " << index_.ToString() << " which is";
914     Indent(os, indent + kIndentInc);
915     subshape_.DescribeTo(os, indent + kIndentInc);
916   }
917 
918  private:
919   Shape* GetSubshape(Shape* shape) const {
920     return ShapeUtil::GetMutableSubshape(shape, index_);
921   }
922   const Shape* GetSubshape(const Shape* shape) const {
923     return &ShapeUtil::GetSubshape(*shape, index_);
924   }
925 
926   template <typename ShapeType>
927   bool MatchImpl(ShapeType* shape, MatchOption option) const {
928     if (!ShapeUtil::IndexIsValid(*shape, index_)) {
929       EXPLAIN << "No subshape at " << index_.ToString();
930       return false;
931     }
932     if (!subshape_.Match(GetSubshape(shape), option)) {
933       EXPLAIN << "\nin subshape at " << index_.ToString();
934       return false;
935     }
936     return true;
937   }
938 
939   ShapeIndexView index_;
940   ShapePattern<SubshapeType, SubshapeImpl> subshape_;
941 };
942 
943 // A pattern that matches Shapes.
944 template <typename ShapeType, typename Impl>
945 class ShapePattern {
946  private:
947   template <typename NewImpl>
948   auto AppendImpl(NewImpl new_impl) const
949       -> ShapePattern<ShapeType, decltype(AllOf<Shape>(std::declval<Impl>(),
950                                                        std::move(new_impl)))> {
951     auto new_all_of = AllOf<Shape>(impl_, std::move(new_impl));
952     return ShapePattern<ShapeType, decltype(new_all_of)>(std::move(new_all_of),
953                                                          matched_shape_);
954   }
955 
956  public:
957   explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
958       : impl_(impl), matched_shape_(matched_shape) {}
959 
960   // Returns true and captures the shape iff it matches the pattern.
961   bool Match(const ::xla::Shape* shape, MatchOption option) const {
962     if (impl_.Match(shape, option)) {
963       if (option.capture && matched_shape_) {
964         *matched_shape_ = shape;
965       }
966       return true;
967     }
968     if (shape) {
969       EXPLAIN << "\nin "
970               << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
971                                       : ShapeUtil::HumanString(*shape));
972     }
973     return false;
974   }
975 
976   // Returns true and captures the shape iff it matches the pattern.
977   bool Match(::xla::Shape* shape, MatchOption option) const {
978     if (impl_.Match(shape, option)) {
979       if (option.capture && matched_shape_) {
980         *matched_shape_ = shape;
981       }
982       return true;
983     }
984     EXPLAIN << "\nin "
985             << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
986                                     : ShapeUtil::HumanString(*shape));
987     return false;
988   }
989 
990   void DescribeTo(std::ostream* os, int64 indent = 0) const {
991     return impl_.DescribeTo(os, indent);
992   }
993 
994   // Modifies the pattern to match only if the shape equals the given proto.
995   // The layout must outlive the returned pattern.
996   constexpr auto EqualTo(const ::xla::Shape* shape) const
997       -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) {
998     return AppendImpl(ShapePatternEqualImpl(shape));
999   }
1000 
1001   // Modifies the pattern to match only if the shape is compatible to the given
1002   // proto. The layout must outlive the returned pattern.
1003   constexpr auto CompatibleTo(const ::xla::Shape* shape) const
1004       -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) {
1005     return AppendImpl(ShapePatternCompatibleImpl(shape));
1006   }
1007 
1008   // Modifies the pattern to match only if the shape has the given element type.
1009   constexpr auto WithElementType(PrimitiveType element_type) const
1010       -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) {
1011     return AppendImpl(ShapePatternElementTypeImpl(element_type));
1012   }
1013 
1014   // Modifies the pattern to match only if the shape is scalar.
1015   constexpr auto IsScalar() const
1016       -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) {
1017     return AppendImpl(ShapePatternIsScalarImpl());
1018   }
1019 
1020   // Modifies the pattern to match only if the shape is an array.
1021   constexpr auto IsArray() const
1022       -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) {
1023     return AppendImpl(ShapePatternIsArrayImpl());
1024   }
1025 
1026   // Modifies the pattern to match only if the shape is a tuple.
1027   constexpr auto IsTuple() const
1028       -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) {
1029     return AppendImpl(ShapePatternIsTupleImpl());
1030   }
1031 
1032   constexpr auto IsEffectiveScalar() const
1033       -> decltype(this->AppendImpl(ShapePatternEffectiveScalarImpl())) {
1034     return AppendImpl(ShapePatternEffectiveScalarImpl());
1035   }
1036 
1037   // Modifies the pattern to match only if the shape has the given rank.
1038   constexpr auto WithRank(int64 rank) const
1039       -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) {
1040     return AppendImpl(ShapePatternRankImpl(rank));
1041   }
1042 
1043   // Modifies the pattern to match only if the shape has a layout that matches
1044   // the given pattern.
1045   template <typename LayoutType, typename LayoutImpl>
1046   auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const
1047       -> decltype(this->AppendImpl(
1048           ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout))) {
1049     return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
1050   }
1051 
1052   constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const
1053       -> decltype(this->WithLayout(Layout().EqualTo(layout))) {
1054     return WithLayout(Layout().EqualTo(layout));
1055   }
1056 
1057   constexpr auto IsDenseArray() const
1058       -> decltype(this->WithLayout(Layout().WithDenseFormat())) {
1059     return WithLayout(Layout().WithDenseFormat());
1060   }
1061 
1062   constexpr auto IsSparseArray() const
1063       -> decltype(this->WithLayout(Layout().WithSparseFormat())) {
1064     return WithLayout(Layout().WithSparseFormat());
1065   }
1066 
1067   // Modifies the pattern to match only if the shape has a subshape that matches
1068   // the given pattern.
1069   template <typename SubshapeType, typename SubshapeImpl>
1070   auto WithSubshape(ShapeIndexView index,
1071                     const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
1072       const -> decltype(this->AppendImpl(
1073           ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index,
1074                                                                subshape))) {
1075     return AppendImpl(
1076         ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
1077   }
1078 
1079   ShapePattern<ShapeType,
1080                AllOfPattern<Shape, Impl,
1081                             ShapePatternSubshapeImpl<
1082                                 const ::xla::Shape,
1083                                 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1084                                              ShapePatternEqualImpl>>>>
1085   WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
1086     return WithSubshape(index,
1087                         ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1088                             ShapePatternBaseImpl(), nullptr)
1089                             .EqualTo(shape));
1090   }
1091 
1092   ShapePattern<ShapeType,
1093                AllOfPattern<Shape, Impl,
1094                             ShapePatternSubshapeImpl<
1095                                 const ::xla::Shape,
1096                                 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1097                                              ShapePatternCompatibleImpl>>>>
1098   WithSubshapeCompatibleTo(ShapeIndexView index,
1099                            const ::xla::Shape* shape) const {
1100     return WithSubshape(index,
1101                         ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1102                             ShapePatternBaseImpl(), nullptr)
1103                             .CompatibleTo(shape));
1104   }
1105 
1106  private:
1107   Impl impl_;
1108   ShapeType** matched_shape_;
1109 };
1110 
1111 }  // namespace detail
1112 
1113 // Creates a shape pattern that will capture the matched layout in the argument.
1114 inline constexpr detail::ShapePattern<const ::xla::Shape,
1115                                       detail::ShapePatternBaseImpl>
1116 Shape(const ::xla::Shape** matched_shape = nullptr) {
1117   return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>(
1118       detail::ShapePatternBaseImpl(), matched_shape);
1119 }
1120 
1121 // Creates a shape pattern that will capture the matched layout in the argument.
1122 inline constexpr detail::ShapePattern<::xla::Shape,
1123                                       detail::ShapePatternBaseImpl>
1124 Shape(::xla::Shape** matched_shape) {
1125   return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>(
1126       detail::ShapePatternBaseImpl(), matched_shape);
1127 }
1128 
1129 namespace detail {
1130 
1131 // Overloads to get a const or non-const operand out of an instruction.
1132 inline HloInstruction* HloOperand(HloInstruction* instr, int64 idx) {
1133   return instr->mutable_operand(idx);
1134 }
1135 inline const HloInstruction* HloOperand(const HloInstruction* instr,
1136                                         int64 idx) {
1137   return instr->operand(idx);
1138 }
1139 
1140 // Pretty-printer for HloInstruction.  Sort of like ToShortString, but with
1141 // fewer %s and more shapes.
1142 inline string InstToString(const HloInstruction* inst) {
1143   return inst->ToString(
1144       HloPrintOptions().set_print_metadata(false).set_print_percent(false));
1145 }
1146 
// Forward declaration. The operand-matching impls below (for example
// HloInstructionPatternOperandImpl) hold HloInstructionPattern members, so the
// template name must be declared before them.
template <typename HloInstructionType, typename Impl>
class HloInstructionPattern;
1149 
1150 // The base HloInstructionPattern implementation. Matches only if the
1151 // instruction is not nullptr.
1152 class HloInstructionPatternBaseImpl {
1153  public:
1154   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1155     if (inst == nullptr) {
1156       EXPLAIN << "HloInstruction* is null";
1157       return false;
1158     }
1159     return true;
1160   }
1161 
1162   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1163     *os << "an HloInstruction";
1164   }
1165 
1166   static constexpr bool kIsTrivialMatcher = true;
1167 };
1168 
1169 // An HloInstructionPattern implementation that matches only if the instruction
1170 // has a given name.
1171 class HloInstructionPatternNameImpl {
1172  public:
1173   explicit HloInstructionPatternNameImpl(absl::string_view name)
1174       : name_(name) {}
1175 
1176   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1177     if (inst->name() != name_) {
1178       EXPLAIN << "HloInstruction not named \"" << name_ << "\"";
1179       return false;
1180     }
1181     return true;
1182   }
1183 
1184   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1185     *os << "named \"" << name_ << "\"";
1186   }
1187 
1188  private:
1189   absl::string_view name_;
1190 };
1191 
1192 // An HloInstructionPattern implementation that matches only if the instruction
1193 // equals a particular pointer.
1194 class HloInstructionIsImpl {
1195  public:
1196   explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {}
1197 
1198   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1199     if (inst != inst_) {
1200       EXPLAIN << "HloInstruction " << inst << " is not " << inst_ << " ("
1201               << InstToString(inst_) << ")";
1202       return false;
1203     }
1204     return true;
1205   }
1206 
1207   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1208     *os << "which is " << inst_ << " (" << InstToString(inst_) << ")";
1209   }
1210 
1211  private:
1212   const HloInstruction* inst_;
1213 };
1214 
1215 // An HloInstructionPattern implementation that matches only if the instruction
1216 // has a given opcode.
1217 class HloInstructionPatternOpcodeImpl {
1218  public:
1219   explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode,
1220                                                      bool invert)
1221       : opcode_(opcode), invert_(invert) {}
1222 
1223   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1224     if (invert_ && inst->opcode() == opcode_) {
1225       EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_)
1226               << ", expected anything else";
1227       return false;
1228     }
1229     if (!invert_ && inst->opcode() != opcode_) {
1230       EXPLAIN << "HloInstruction doesn't have opcode "
1231               << HloOpcodeString(opcode_);
1232       return false;
1233     }
1234     return true;
1235   }
1236 
1237   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1238     if (!invert_) {
1239       *os << "with opcode " << HloOpcodeString(opcode_);
1240     } else {
1241       *os << "with any opcode other than " << HloOpcodeString(opcode_);
1242     }
1243   }
1244 
1245  private:
1246   HloOpcode opcode_;
1247   bool invert_;
1248 };
1249 
1250 // An HloInstructionPattern implementation that matches only if the instruction
1251 // has the given number of operands.
1252 class HloInstructionPatternNumOperandsImpl {
1253  public:
1254   explicit constexpr HloInstructionPatternNumOperandsImpl(int64 num_operands)
1255       : num_operands_(num_operands) {}
1256 
1257   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1258     if (inst->operand_count() != num_operands_) {
1259       EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands";
1260       return false;
1261     }
1262     return true;
1263   }
1264 
1265   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1266     *os << "with " << num_operands_ << " operand"
1267         << (num_operands_ != 1 ? "s" : "");
1268   }
1269 
1270  private:
1271   int64 num_operands_;
1272 };
1273 
1274 // An HloInstructionPattern implementation that matches only if the instruction
1275 // has a shape that matches a given pattern.
1276 template <typename ShapeType, typename ShapeImpl>
1277 class HloInstructionPatternShapeImpl {
1278  public:
1279   explicit constexpr HloInstructionPatternShapeImpl(
1280       const ShapePattern<ShapeType, ShapeImpl>& shape)
1281       : shape_(shape) {}
1282 
1283   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1284     if (!shape_.Match(&inst->shape(), option)) {
1285       EXPLAIN << "\nin output shape";
1286       return false;
1287     }
1288     return true;
1289   }
1290 
1291   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1292     if (!shape_.Match(inst->mutable_shape(), option)) {
1293       EXPLAIN << "\nin output shape";
1294       return false;
1295     }
1296     return true;
1297   }
1298 
1299   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1300     *os << "outputting";
1301     Indent(os, indent + kIndentInc);
1302     shape_.DescribeTo(os, indent + kIndentInc);
1303   }
1304 
1305  private:
1306   ShapePattern<ShapeType, ShapeImpl> shape_;
1307 };
1308 
1309 // An HloInstructionPattern implementation that matches only if the instruction
1310 // has an operand that matches a given pattern.
1311 template <typename OperandType, typename OperandImpl>
1312 class HloInstructionPatternOperandImpl {
1313  public:
1314   explicit constexpr HloInstructionPatternOperandImpl(
1315       int64 operand_index,
1316       const HloInstructionPattern<OperandType, OperandImpl>& operand)
1317       : operand_index_(operand_index), operand_(operand) {}
1318 
1319   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1320     return MatchImpl(inst, option);
1321   }
1322 
1323   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1324     return MatchImpl(inst, option);
1325   }
1326 
1327   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1328     *os << "with operand " << operand_index_ << " which is:";
1329     Indent(os, indent + kIndentInc);
1330     operand_.DescribeTo(os, indent + kIndentInc);
1331   }
1332 
1333  private:
1334   template <typename HloInstructionType>
1335   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1336     if (operand_index_ >= inst->operand_count()) {
1337       EXPLAIN << "desired operand index " << operand_index_
1338               << " is out of bounds";
1339       return false;
1340     }
1341     if (!operand_.Match(HloOperand(inst, operand_index_), option)) {
1342       EXPLAIN << "\nin operand " << operand_index_;
1343       return false;
1344     }
1345     return true;
1346   }
1347 
1348   int64 operand_index_;
1349   HloInstructionPattern<OperandType, OperandImpl> operand_;
1350 };
1351 
1352 // Matches a binary instruction whose operands come in any order.
1353 template <typename OperandType1, typename OperandImpl1, typename OperandType2,
1354           typename OperandImpl2>
1355 class HloInstructionPatternBinaryOperandsAnyOrderImpl {
1356  public:
1357   explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl(
1358       const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
1359       const HloInstructionPattern<OperandType2, OperandImpl2>& op2)
1360       : op1_(op1), op2_(op2) {}
1361 
1362   bool Match(HloInstruction* inst, MatchOption option) const {
1363     return MatchImpl(inst, option);
1364   }
1365 
1366   bool Match(const HloInstruction* inst, MatchOption option) const {
1367     return MatchImpl(inst, option);
1368   }
1369 
1370   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1371     *os << "with two operands in either order:";
1372     Indent(os, indent);
1373     *os << " - ";
1374     op1_.DescribeTo(os, indent + 3);
1375     Indent(os, indent);
1376     *os << " - ";
1377     op2_.DescribeTo(os, indent + 3);
1378   }
1379 
1380  private:
1381   HloInstruction* operand(HloInstruction* inst, int64 idx) const {
1382     return inst->mutable_operand(idx);
1383   }
1384   const HloInstruction* operand(const HloInstruction* inst, int64 idx) const {
1385     return inst->operand(idx);
1386   }
1387 
1388   template <typename HloInstructionType>
1389   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1390     // We could implement this using AnyOf and AllOf matchers, but the templates
1391     // get pretty difficult to debug, since any compile error herein becomes
1392     // not-an-error via SFINAE.  Also this way lets us give better messages on
1393     // failure.
1394     if (inst->operand_count() != 2) {
1395       EXPLAIN << "HloInstruction did not have two operands";
1396       return false;
1397     }
1398 
1399     // If we're not generating explanations, this is pretty simple.
1400     if (!option.explain_os) {
1401       auto try_match = [&](int64 idx1, int64 idx2) {
1402         MatchOption new_option = option;
1403         new_option.capture = false;
1404         if (op1_.Match(operand(inst, idx1), new_option) &&
1405             op2_.Match(operand(inst, idx2), new_option)) {
1406           if (option.capture) {
1407             bool matched = op1_.Match(operand(inst, idx1), option) &&
1408                            op2_.Match(operand(inst, idx2), option);
1409             DCHECK(matched);
1410           }
1411           return true;
1412         }
1413         return false;
1414       };
1415       return try_match(0, 1) || try_match(1, 0);
1416     }
1417 
1418     // If we are generating explanations, we have some work to do in order to
1419     // generate a helpful error.
1420     //
1421     // First, try all four operand/matcher combinations, recording the
1422     // failure explanations separately from option.explain_os. matches[i][j]
1423     // tells us if matcher_i matches operand j.
1424     bool matches[/*matcher*/ 2][/*operand*/ 2];
1425     std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2];
1426     for (int i = 0; i < 2; ++i) {
1427       for (int j = 0; j < 2; ++j) {
1428         MatchOption new_option = option;
1429         new_option.capture = false;
1430         new_option.explain_os = &explanations[i][j];
1431         matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option)
1432                                : op2_.Match(operand(inst, j), new_option);
1433       }
1434     }
1435 
1436     // Check if the match succeeded.
1437     for (int i = 0; i < 2; ++i) {
1438       if (matches[0][i] && matches[1][(i + 1) % 2]) {
1439         // Rerun the matches with capture enabled if necessary.
1440         if (option.capture) {
1441           auto* operand1 = operand(inst, i);
1442           auto* operand2 = operand(inst, (i + 1) % 2);
1443           bool matched =
1444               op1_.Match(operand1, option) && op2_.Match(operand2, option);
1445           DCHECK(matched);
1446         }
1447         return true;
1448       }
1449     }
1450 
1451     auto describe_matcher = [&](int matcher_idx) {
1452       EXPLAIN << "\n - ";
1453       if (matcher_idx == 0) {
1454         op1_.DescribeTo(option.explain_os, /*indent=*/3);
1455       } else {
1456         CHECK_EQ(matcher_idx, 1);
1457         op2_.DescribeTo(option.explain_os, /*indent=*/3);
1458       }
1459       for (int i = 0; i < 2; ++i) {
1460         if (matches[matcher_idx][/*operand*/ i]) {
1461           continue;
1462         }
1463         EXPLAIN << "\ndoes not match " << (i == 0 ? "LHS" : "RHS") << ":\n";
1464         EXPLAIN << " - ";
1465         EXPLAIN << absl::StrReplaceAll(
1466             explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n   "}});
1467       }
1468     };
1469 
1470     // If we failed to match, one of the following is true:
1471     //  1. op1 (op2) matches neither LHS nor RHS, or
1472     //  2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS).
1473     // We print different explanations depending on which case we're in.
1474 
1475     // Case 1.
1476     bool wrote_explanation = false;
1477     for (int i = 0; !wrote_explanation && i < 2; ++i) {
1478       if (!matches[i][0] && !matches[i][1]) {
1479         EXPLAIN << "HloInstruction's operands (ignoring order) did not match "
1480                 << (i == 0 ? "first" : "second") << " matcher.  Specifically,";
1481         describe_matcher(i);
1482         wrote_explanation = true;
1483       }
1484     }
1485 
1486     // Case 2.
1487     for (int i = 0; !wrote_explanation && i < 2; ++i) {
1488       if (matches[/*matcher*/ 0][/*operand*/ i] &&
1489           matches[/*matcher*/ 1][/*operand*/ i]) {
1490         CHECK(!matches[0][(i + 1) % 2]);
1491         CHECK(!matches[1][(i + 1) % 2]);
1492         CHECK(!wrote_explanation);
1493         EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS")
1494                 << " operand did not match either of the two matchers.  "
1495                    "Specifically,";
1496         describe_matcher(0);
1497         EXPLAIN << "\nand";
1498         describe_matcher(1);
1499         wrote_explanation = true;
1500       }
1501     }
1502 
1503     CHECK(wrote_explanation);
1504     return false;
1505   }
1506 
1507   HloInstructionPattern<OperandType1, OperandImpl1> op1_;
1508   HloInstructionPattern<OperandType2, OperandImpl2> op2_;
1509 };
1510 
1511 // An HloInstructionPattern implementation that matches only if the instruction
1512 // is a fusion node with a particular kind.
1513 class HloInstructionPatternFusionKindImpl {
1514  public:
1515   explicit constexpr HloInstructionPatternFusionKindImpl(
1516       ::xla::HloInstruction::FusionKind kind)
1517       : kind_(kind) {}
1518 
1519   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1520     return MatchImpl(inst, option);
1521   }
1522 
1523   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1524     return MatchImpl(inst, option);
1525   }
1526 
1527   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1528     *os << "with fusion kind " << ToString(kind_);
1529   }
1530 
1531  private:
1532   template <typename HloInstructionType>
1533   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1534     if (inst->opcode() != HloOpcode::kFusion) {
1535       EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_)
1536               << "; it's not a fusion";
1537       return false;
1538     }
1539     if (inst->fusion_kind() != kind_) {
1540       EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_);
1541       return false;
1542     }
1543     return true;
1544   }
1545 
1546   ::xla::HloInstruction::FusionKind kind_;
1547 };
1548 
1549 // An HloInstructionPattern implementation that matches only if the instruction
1550 // is a kGetTupleElement with a particular tuple index.
1551 class HloInstructionPatternTupleIndexImpl {
1552  public:
1553   explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index)
1554       : tuple_index_(tuple_index) {}
1555 
1556   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1557     return MatchImpl(inst, option);
1558   }
1559 
1560   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1561     return MatchImpl(inst, option);
1562   }
1563 
1564   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1565     *os << "which is a GTE with index " << tuple_index_;
1566   }
1567 
1568  private:
1569   template <typename HloInstructionType>
1570   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1571     if (inst->opcode() != HloOpcode::kGetTupleElement) {
1572       EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_
1573               << "; it's not a GTE at all";
1574       return false;
1575     }
1576     if (inst->tuple_index() != tuple_index_) {
1577       EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_;
1578       return false;
1579     }
1580     return true;
1581   }
1582 
1583   int64 tuple_index_;
1584 };
1585 
1586 class HloInstructionPatternParameterNumImpl {
1587  public:
1588   explicit constexpr HloInstructionPatternParameterNumImpl(int64 parameter_num)
1589       : parameter_num_(parameter_num) {}
1590 
1591   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1592     return MatchImpl(inst, option);
1593   }
1594 
1595   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1596     return MatchImpl(inst, option);
1597   }
1598 
1599   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1600     *os << "which is parameter " << parameter_num_;
1601   }
1602 
1603  private:
1604   template <typename HloInstructionType>
1605   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1606     if (inst->opcode() != HloOpcode::kParameter ||
1607         inst->parameter_number() != parameter_num_) {
1608       EXPLAIN << "HloInstruction is not parameter " << parameter_num_;
1609       return false;
1610     }
1611     return true;
1612   }
1613 
1614   int64 parameter_num_;
1615 };
1616 
1617 // Superclass that contains common code used by Op::WithOneUse() and
1618 // Op::WithOneUser().
1619 class HloInstructionPatternOneUseOrUserImpl {
1620  protected:
1621   bool MatchOneUser(const HloInstruction* inst, MatchOption option) const {
1622     if (inst->user_count() != 1) {
1623       EXPLAIN << "HloInstruction has " << inst->user_count()
1624               << " users, but expected exactly one.";
1625       if (inst->user_count() > 1) {
1626         EXPLAIN << "\nAll users:";
1627         for (const HloInstruction* user : inst->users()) {
1628           EXPLAIN << "\n - " << InstToString(user);
1629         }
1630       }
1631       return false;
1632     }
1633     return true;
1634   }
1635 };
1636 
1637 class HloInstructionPatternOneUseImpl
1638     : public HloInstructionPatternOneUseOrUserImpl {
1639  public:
1640   bool Match(const HloInstruction* inst, MatchOption option) const {
1641     if (!MatchOneUser(inst, option)) {
1642       return false;
1643     }
1644 
1645     int64 use_count = absl::c_count_if(
1646         inst->users()[0]->operands(),
1647         [&](const HloInstruction* operand) { return operand == inst; });
1648     if (use_count != 1) {
1649       EXPLAIN << "HloInstruction is used " << use_count
1650               << " times by its user, but is expected to be used just once: "
1651               << InstToString(inst->users()[0]);
1652       return false;
1653     }
1654     return true;
1655   }
1656 
1657   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1658     *os << "which has exactly one use";
1659   }
1660 };
1661 
1662 class HloInstructionPatternOneUserImpl
1663     : public HloInstructionPatternOneUseOrUserImpl {
1664  public:
1665   bool Match(const HloInstruction* inst, MatchOption option) const {
1666     return MatchOneUser(inst, option);
1667   }
1668 
1669   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1670     *os << "which has exactly one user (but possibly is used multiple times by "
1671            "that instruction)";
1672   }
1673 };
1674 
1675 class HloInstructionPatternComparisonDirectionImpl {
1676  public:
1677   explicit constexpr HloInstructionPatternComparisonDirectionImpl(
1678       ComparisonDirection direction)
1679       : direction_(direction) {}
1680 
1681   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1682     return MatchImpl(inst, option);
1683   }
1684 
1685   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1686     return MatchImpl(inst, option);
1687   }
1688 
1689   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1690     *os << "which has comparison direction "
1691         << ComparisonDirectionToString(direction_);
1692   }
1693 
1694  private:
1695   template <typename HloInstructionType>
1696   bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1697     if (inst->opcode() != HloOpcode::kCompare ||
1698         inst->comparison_direction() != direction_) {
1699       EXPLAIN << "HloInstruction is not comparison "
1700               << ComparisonDirectionToString(direction_);
1701       return false;
1702     }
1703     return true;
1704   }
1705 
1706   ComparisonDirection direction_;
1707 };
1708 
1709 // Matches a constant scalar or effective scalar, optionally with a given value.
1710 template <typename ScalarTy>
1711 class HloConstantScalarImpl {
1712  public:
1713   explicit constexpr HloConstantScalarImpl(bool match_effective_scalar)
1714       : val_(absl::nullopt), match_effective_scalar_(match_effective_scalar) {}
1715 
1716   constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar)
1717       : val_(val), match_effective_scalar_(match_effective_scalar) {}
1718 
1719   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1720     return MatchImpl(inst, option);
1721   }
1722 
1723   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1724     return MatchImpl(inst, option);
1725   }
1726 
1727   void DescribeTo(std::ostream* os, int64 indent = 0) const {
1728     *os << "which is a constant "
1729         << (match_effective_scalar_ ? "effective " : "") << "scalar";
1730     if (val_.has_value()) {
1731       *os << " with value " << *val_;
1732     }
1733   }
1734 
1735  private:
1736   template <typename InstTy>
1737   bool MatchImpl(InstTy* inst, MatchOption option) const {
1738     const auto* const_inst = DynCast<HloConstantInstruction>(inst);
1739     if (!const_inst) {
1740       EXPLAIN << "HloInstruction is not a constant";
1741       return false;
1742     }
1743     if (match_effective_scalar_ &&
1744         !ShapeUtil::IsEffectiveScalar(inst->shape())) {
1745       EXPLAIN << "HloInstruction is not an effective scalar";
1746       return false;
1747     }
1748     if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) {
1749       EXPLAIN << "HloInstruction is not a scalar";
1750       return false;
1751     }
1752     if (!val_.has_value()) {
1753       return true;
1754     }
1755 
1756     // Check that literal == static_cast<LitearlTy>(val) and
1757     // val == static_cast<ValTy>(literal).  This is sufficient to ensure that
1758     // the two constant scalars are actually "equal".
1759     auto val_literal = LiteralUtil::CreateR0(*val_);
1760     auto literal_r0_or = const_inst->literal().Reshape({});
1761     auto val_as_literal_ty_or =
1762         val_literal.Convert(const_inst->shape().element_type());
1763     if (!literal_r0_or.ok() || !val_as_literal_ty_or.ok()) {
1764       EXPLAIN << "could not construct relevant Literals (how did this happen?)";
1765       return false;
1766     }
1767     auto literal_r0 = std::move(literal_r0_or).ValueOrDie();
1768     auto val_as_literal_ty = std::move(val_as_literal_ty_or).ValueOrDie();
1769     auto literal_r0_as_val_ty_or =
1770         literal_r0.Convert(val_literal.shape().element_type());
1771     bool rv = literal_r0_as_val_ty_or.ok() &&  //
1772               literal_r0_as_val_ty_or.ValueOrDie() == val_literal &&
1773               literal_r0 == val_as_literal_ty;
1774     if (!rv) {
1775       EXPLAIN << "HloInstruction's constant value "
1776               << literal_r0.ToStringWithoutShape()
1777               << " did not match expected value " << *val_;
1778     }
1779     return rv;
1780   }
1781 
1782   absl::optional<ScalarTy> val_;
1783   bool match_effective_scalar_;
1784 };
1785 
1786 // A pattern that matches HloInstructions.
1787 template <typename HloInstructionType, typename Impl>
1788 class HloInstructionPattern {
1789  private:
1790   template <typename NewImpl>
1791   auto AppendImpl(NewImpl new_impl) const -> HloInstructionPattern<
1792       HloInstructionType, decltype(AllOf<HloInstruction>(
1793                               std::declval<Impl>(), std::move(new_impl)))> {
1794     auto new_allof = AllOf<HloInstruction>(impl_, std::move(new_impl));
1795     return HloInstructionPattern<HloInstructionType, decltype(new_allof)>(
1796         std::move(new_allof), matched_inst_);
1797   }
1798 
1799  public:
1800   explicit constexpr HloInstructionPattern(const Impl& impl,
1801                                            HloInstructionType** matched_inst)
1802       : impl_(impl), matched_inst_(matched_inst) {}
1803 
1804   // Returns true and captures the instruction iff it matches the pattern.
1805   bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1806     if (impl_.Match(inst, option)) {
1807       if (option.capture && matched_inst_) {
1808         *matched_inst_ = inst;
1809       }
1810       return true;
1811     }
1812     if (inst != nullptr) {
1813       EXPLAIN << "\nin " << InstToString(inst);
1814     }
1815     return false;
1816   }
1817 
1818   // Returns true and captures the instruction iff it matches the pattern.
1819   bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1820     if (impl_.Match(inst, option)) {
1821       if (option.capture && matched_inst_) {
1822         *matched_inst_ = inst;
1823       }
1824       return true;
1825     }
1826     EXPLAIN << "\nin " << InstToString(inst);
1827     return false;
1828   }
1829 
1830   // Modifies the pattern to match only if the instruction has the given name.
1831   auto WithName(absl::string_view name) const
1832       -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) {
1833     return AppendImpl(HloInstructionPatternNameImpl(name));
1834   }
1835 
1836   // Modifies the pattern to match only if the instruction has the given opcode.
1837   auto WithOpcode(HloOpcode opcode) const
1838       -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
1839                                                                    false))) {
1840     return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
1841   }
1842 
1843   auto WithNumOperands(int64 num_operands) const -> decltype(
1844       this->AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands))) {
1845     return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
1846   }
1847 
1848   // Modifies the pattern to match only if the instruction does not have the
1849   // given opcode.
1850   auto WithoutOpcode(HloOpcode opcode) const
1851       -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
1852                                                                    true))) {
1853     return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true));
1854   }
1855 
1856   constexpr auto Is(const HloInstruction* instr) const
1857       -> decltype(this->AppendImpl(HloInstructionIsImpl(instr))) {
1858     return AppendImpl(HloInstructionIsImpl(instr));
1859   }
1860 
1861   // Modifies the pattern to match only if the instruction is a constant.
1862   constexpr auto IsConstant() const
1863       -> decltype(this->WithOpcode(HloOpcode::kConstant)) {
1864     return WithOpcode(HloOpcode::kConstant);
1865   }
1866 
1867   constexpr auto IsConstantScalar() const -> decltype(this->AppendImpl(
1868       HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false))) {
1869     return AppendImpl(
1870         HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false));
1871   }
1872 
1873   // This does not check that T has the same type as the instruction, so e.g.
1874   // IsConstantScalar(1.0) may match a constant of shape int32[].
1875   template <typename ScalarTy>
1876   constexpr auto IsConstantScalar(const ScalarTy& val) const
1877       -> decltype(this->AppendImpl(HloConstantScalarImpl<ScalarTy>(
1878           val, /*match_effective_scalar=*/false))) {
1879     return AppendImpl(
1880         HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/false));
1881   }
1882 
1883   constexpr auto IsConstantEffectiveScalar() const -> decltype(this->AppendImpl(
1884       HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true))) {
1885     return AppendImpl(
1886         HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true));
1887   }
1888 
1889   template <typename ScalarTy>
1890   constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const
1891       -> decltype(this->AppendImpl(HloConstantScalarImpl<ScalarTy>(
1892           val, /*match_effective_scalar=*/true))) {
1893     return AppendImpl(
1894         HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/true));
1895   }
1896 
1897   // Modifies the pattern to match only if the instruction is not a constant.
1898   constexpr auto IsNonConstant() const
1899       -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) {
1900     return WithoutOpcode(HloOpcode::kConstant);
1901   }
1902 
1903   // Modifies the pattern to match only if the instruction has a shape that
1904   // matches the given pattern.
1905   template <typename ShapeType, typename ShapeImpl>
1906   constexpr auto WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape)
1907       const -> decltype(this->AppendImpl(
1908           HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape))) {
1909     return AppendImpl(
1910         HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
1911   }
1912 
1913   // Make this a templated function to work around gcc 4.9.4 template infinite
1914   // recursion bug.
1915   template <typename Dummy = void>
1916   constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const
1917       -> decltype(this->WithShape(Shape().EqualTo(shape))) {
1918     return WithShape(Shape().EqualTo(shape));
1919   }
1920 
1921   // Make this a templated function to work around gcc 4.9.4 template infinite
1922   // recursion bug.
1923   template <typename Dummy = void>
1924   constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const
1925       -> decltype(this->WithShape(Shape().CompatibleTo(shape))) {
1926     return WithShape(Shape().CompatibleTo(shape));
1927   }
1928 
1929   // Modifies the pattern to match only if the instruction has an operand that
1930   // matches the given pattern.
1931   template <typename OperandType, typename OperandImpl>
1932   constexpr auto WithOperand(
1933       int64 operand_index,
1934       const HloInstructionPattern<OperandType, OperandImpl>& operand) const
1935       -> decltype(this->AppendImpl(
1936           HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
1937               operand_index, operand))) {
1938     return AppendImpl(
1939         HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
1940             operand_index, operand));
1941   }
1942 
1943   template <typename OperandType1, typename OperandImpl1, typename OperandType2,
1944             typename OperandImpl2>
1945   constexpr auto WithBinaryOperandsAnyOrder(
1946       const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
1947       const HloInstructionPattern<OperandType2, OperandImpl2>& op2) const
1948       -> decltype(this->AppendImpl(
1949           HloInstructionPatternBinaryOperandsAnyOrderImpl<
1950               OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1,
1951                                                                       op2))) {
1952     return AppendImpl(
1953         HloInstructionPatternBinaryOperandsAnyOrderImpl<
1954             OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2));
1955   }
1956 
1957   // Modifies the pattern to match only if the instruction is a fusion node with
1958   // the given kind.
1959   constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const
1960       -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) {
1961     return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
1962   }
1963 
1964   // Modifies the pattern to match only if the instruction is a
1965   // get-tuple-element with the given tuple index.
1966   constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype(
1967       this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) {
1968     return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
1969   }
1970 
1971   // Modifies the pattern to match only if the instruction is a parameter
1972   // with the given parameter number.
1973   constexpr auto WithParameterNum(int64 parameter_num) const -> decltype(
1974       this->AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num))) {
1975     return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num));
1976   }
1977 
1978   // Modifies the pattern to match if the instruction is used exactly once.
1979   // Does not match if the instruction is used twice by the same user (e.g.
1980   // multiply(x,x)).
1981   constexpr auto WithOneUse() const
1982       -> decltype(this->AppendImpl(HloInstructionPatternOneUseImpl())) {
1983     return AppendImpl(HloInstructionPatternOneUseImpl());
1984   }
1985 
1986   // Modifies the pattern to match if the instruction is used by exactly one
1987   // other instruction.  Will match if the instruction is used twice, so long as
1988   // it's by the same user (e.g.  multiply(x,x)).
1989   constexpr auto WithOneUser() const
1990       -> decltype(this->AppendImpl(HloInstructionPatternOneUserImpl())) {
1991     return AppendImpl(HloInstructionPatternOneUserImpl());
1992   }
1993 
1994   // Modifies the pattern to match only if the instruction has the given
1995   // comparison direction.
1996   auto WithComparisonDirection(ComparisonDirection direction) const
1997       -> decltype(this->AppendImpl(
1998           HloInstructionPatternComparisonDirectionImpl(direction))) {
1999     return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
2000   }
2001 
2002   void DescribeTo(std::ostream* os, int64 indent = 0) const {
2003     impl_.DescribeTo(os, indent);
2004   }
2005 
2006  private:
2007   Impl impl_;
2008   HloInstructionType** matched_inst_;
2009 };
2010 
2011 }  // namespace detail
2012 
2013 // Creates an instruction pattern that will capture the matched instruction in
2014 // the argument.
2015 inline constexpr detail::HloInstructionPattern<
2016     const ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl>
2017 Op(const ::xla::HloInstruction** matched_inst = nullptr) {
2018   return detail::HloInstructionPattern<const ::xla::HloInstruction,
2019                                        detail::HloInstructionPatternBaseImpl>(
2020       detail::HloInstructionPatternBaseImpl(), matched_inst);
2021 }
2022 
2023 // Creates an instruction pattern that will capture the matched instruction in
2024 // the argument.
2025 inline constexpr detail::HloInstructionPattern<
2026     ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl>
2027 Op(::xla::HloInstruction** matched_inst) {
2028   return detail::HloInstructionPattern<::xla::HloInstruction,
2029                                        detail::HloInstructionPatternBaseImpl>(
2030       detail::HloInstructionPatternBaseImpl(), matched_inst);
2031 }
2032 
2033 // Helpers for nullary instructions.
2034 #define XLA_NULLOP_PATTERN(NAME)                                      \
2035   inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
2036     return Op().WithOpcode(HloOpcode::k##NAME);                       \
2037   }                                                                   \
2038                                                                       \
2039   template <typename HloInstructionType>                              \
2040   inline auto NAME(HloInstructionType** matched_inst)                 \
2041       ->decltype(Op(matched_inst).WithOpcode(HloOpcode::k##NAME)) {   \
2042     return Op(matched_inst).WithOpcode(HloOpcode::k##NAME);           \
2043   }
2044 XLA_NULLOP_PATTERN(Constant)
2045 XLA_NULLOP_PATTERN(Parameter)
2046 XLA_NULLOP_PATTERN(Iota)
2047 XLA_NULLOP_PATTERN(Rng)
2048 #undef XLA_NULLOP_PATTERN
2049 
2050 // Helpers for unary instructions.
2051 #define XLA_UNOP_PATTERN(NAME)                                        \
2052   inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
2053     return Op().WithOpcode(HloOpcode::k##NAME);                       \
2054   }                                                                   \
2055                                                                       \
2056   template <typename Arg>                                             \
2057   inline auto NAME(Arg&& arg)->decltype(                              \
2058       Op().WithOpcode(HloOpcode::k##NAME)                             \
2059           .WithOperand(0, std::forward<Arg>(arg))) {                  \
2060     return Op()                                                       \
2061         .WithOpcode(HloOpcode::k##NAME)                               \
2062         .WithOperand(0, std::forward<Arg>(arg));                      \
2063   }                                                                   \
2064                                                                       \
2065   template <typename HloInstructionType, typename Arg>                \
2066   inline auto NAME(HloInstructionType** matched_inst, Arg&& arg)      \
2067       ->decltype(Op(matched_inst)                                     \
2068                      .WithOpcode(HloOpcode::k##NAME)                  \
2069                      .WithOperand(0, std::forward<Arg>(arg))) {       \
2070     return Op(matched_inst)                                           \
2071         .WithOpcode(HloOpcode::k##NAME)                               \
2072         .WithOperand(0, std::forward<Arg>(arg));                      \
2073   }
2074 XLA_UNOP_PATTERN(Abs)
2075 XLA_UNOP_PATTERN(RoundNearestAfz)
2076 XLA_UNOP_PATTERN(Bitcast)
2077 XLA_UNOP_PATTERN(Broadcast)
2078 XLA_UNOP_PATTERN(Ceil)
2079 XLA_UNOP_PATTERN(Convert)
2080 XLA_UNOP_PATTERN(Copy)
2081 XLA_UNOP_PATTERN(Cos)
2082 XLA_UNOP_PATTERN(AllReduce)
2083 XLA_UNOP_PATTERN(Exp)
2084 XLA_UNOP_PATTERN(Fft)
2085 XLA_UNOP_PATTERN(Floor)
2086 XLA_UNOP_PATTERN(GetTupleElement)
2087 XLA_UNOP_PATTERN(Imag)
2088 XLA_UNOP_PATTERN(Infeed)
2089 XLA_UNOP_PATTERN(IsFinite)
2090 XLA_UNOP_PATTERN(Log)
2091 XLA_UNOP_PATTERN(Not)
2092 XLA_UNOP_PATTERN(Negate)
2093 XLA_UNOP_PATTERN(Real)
2094 XLA_UNOP_PATTERN(Recv)
2095 XLA_UNOP_PATTERN(RecvDone)
2096 XLA_UNOP_PATTERN(ReducePrecision)
2097 XLA_UNOP_PATTERN(Reshape)
2098 XLA_UNOP_PATTERN(Reverse)
2099 XLA_UNOP_PATTERN(Rsqrt)
2100 XLA_UNOP_PATTERN(SendDone)
2101 XLA_UNOP_PATTERN(Sign)
2102 XLA_UNOP_PATTERN(Sin)
2103 XLA_UNOP_PATTERN(Slice)
2104 XLA_UNOP_PATTERN(Sqrt)
2105 XLA_UNOP_PATTERN(Tanh)
2106 XLA_UNOP_PATTERN(Transpose)
2107 #undef XLA_UNOP_PATTERN
2108 
2109 // Helpers for binary instructions.
2110 #define XLA_BINOP_PATTERN(NAME)                                             \
2111   inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) {       \
2112     return Op().WithOpcode(HloOpcode::k##NAME);                             \
2113   }                                                                         \
2114                                                                             \
2115   template <typename Lhs, typename Rhs>                                     \
2116   inline auto NAME(Lhs&& lhs, Rhs&& rhs)                                    \
2117       ->decltype(Op().WithOpcode(HloOpcode::k##NAME)                        \
2118                      .WithOperand(0, std::forward<Lhs>(lhs))                \
2119                      .WithOperand(1, std::forward<Rhs>(rhs))) {             \
2120     return Op()                                                             \
2121         .WithOpcode(HloOpcode::k##NAME)                                     \
2122         .WithOperand(0, std::forward<Lhs>(lhs))                             \
2123         .WithOperand(1, std::forward<Rhs>(rhs));                            \
2124   }                                                                         \
2125                                                                             \
2126   template <typename HloInstructionType, typename Lhs, typename Rhs>        \
2127   inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \
2128       ->decltype(Op(matched_inst)                                           \
2129                      .WithOpcode(HloOpcode::k##NAME)                        \
2130                      .WithOperand(0, std::forward<Lhs>(lhs))                \
2131                      .WithOperand(1, std::forward<Rhs>(rhs))) {             \
2132     return Op(matched_inst)                                                 \
2133         .WithOpcode(HloOpcode::k##NAME)                                     \
2134         .WithOperand(0, std::forward<Lhs>(lhs))                             \
2135         .WithOperand(1, std::forward<Rhs>(rhs));                            \
2136   }
2137 
2138 #define XLA_COMMUTATIVE_BINOP_PATTERN(NAME)                                 \
2139   XLA_BINOP_PATTERN(NAME)                                                   \
2140                                                                             \
2141   template <typename HloInstructionType, typename Lhs, typename Rhs>        \
2142   inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs,  \
2143                              Rhs&& rhs)                                     \
2144       ->decltype(Op(matched_inst)                                           \
2145                      .WithOpcode(HloOpcode::k##NAME)                        \
2146                      .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),    \
2147                                                  std::forward<Rhs>(rhs))) { \
2148     return Op(matched_inst)                                                 \
2149         .WithOpcode(HloOpcode::k##NAME)                                     \
2150         .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                 \
2151                                     std::forward<Rhs>(rhs));                \
2152   }                                                                         \
2153   template <typename Lhs, typename Rhs>                                     \
2154   inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs)                          \
2155       ->decltype(NAME##AnyOrder<const HloInstruction>(                      \
2156           nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs))) {       \
2157     return NAME##AnyOrder<const HloInstruction>(                            \
2158         nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));           \
2159   }
2160 XLA_COMMUTATIVE_BINOP_PATTERN(Add)
2161 XLA_BINOP_PATTERN(Atan2)
2162 XLA_BINOP_PATTERN(Divide)
2163 XLA_BINOP_PATTERN(Complex)
2164 XLA_BINOP_PATTERN(Compare)
2165 XLA_BINOP_PATTERN(Convolution)
2166 XLA_BINOP_PATTERN(Dot)
2167 XLA_BINOP_PATTERN(Gather)
2168 XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
2169 XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
2170 XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
2171 XLA_BINOP_PATTERN(Outfeed)
2172 XLA_BINOP_PATTERN(Pad)
2173 XLA_BINOP_PATTERN(Power)
2174 XLA_BINOP_PATTERN(ReduceWindow)
2175 XLA_BINOP_PATTERN(Remainder)
2176 XLA_BINOP_PATTERN(Send)
2177 XLA_BINOP_PATTERN(Subtract)
2178 XLA_COMMUTATIVE_BINOP_PATTERN(And)
2179 XLA_COMMUTATIVE_BINOP_PATTERN(Or)
2180 XLA_BINOP_PATTERN(ShiftLeft)
2181 XLA_BINOP_PATTERN(ShiftRightArithmetic)
2182 XLA_BINOP_PATTERN(ShiftRightLogical)
2183 #undef XLA_COMMUTATIVE_BINOP_PATTERN
2184 #undef XLA_BINOP_PATTERN
2185 
2186 // Helpers for ternary instructions.
2187 #define XLA_TERNOP_PATTERN(NAME)                                       \
2188   inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) {  \
2189     return Op().WithOpcode(HloOpcode::k##NAME);                        \
2190   }                                                                    \
2191                                                                        \
2192   template <typename Arg0, typename Arg1, typename Arg2>               \
2193   inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2)              \
2194       ->decltype(Op().WithOpcode(HloOpcode::k##NAME)                   \
2195                      .WithOperand(0, std::forward<Arg0>(arg0))         \
2196                      .WithOperand(1, std::forward<Arg1>(arg1))         \
2197                      .WithOperand(2, std::forward<Arg2>(arg2))) {      \
2198     return Op()                                                        \
2199         .WithOpcode(HloOpcode::k##NAME)                                \
2200         .WithOperand(0, std::forward<Arg0>(arg0))                      \
2201         .WithOperand(1, std::forward<Arg1>(arg1))                      \
2202         .WithOperand(2, std::forward<Arg2>(arg2));                     \
2203   }                                                                    \
2204                                                                        \
2205   template <typename HloInstructionType, typename Arg0, typename Arg1, \
2206             typename Arg2>                                             \
2207   inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0,     \
2208                    Arg1&& arg1, Arg2&& arg2)                           \
2209       ->decltype(Op(matched_inst)                                      \
2210                      .WithOpcode(HloOpcode::k##NAME)                   \
2211                      .WithOperand(0, std::forward<Arg0>(arg0))         \
2212                      .WithOperand(1, std::forward<Arg1>(arg1))         \
2213                      .WithOperand(2, std::forward<Arg2>(arg2))) {      \
2214     return Op(matched_inst)                                            \
2215         .WithOpcode(HloOpcode::k##NAME)                                \
2216         .WithOperand(0, std::forward<Arg0>(arg0))                      \
2217         .WithOperand(1, std::forward<Arg1>(arg1))                      \
2218         .WithOperand(2, std::forward<Arg2>(arg2));                     \
2219   }
2220 XLA_TERNOP_PATTERN(Clamp);
2221 XLA_TERNOP_PATTERN(Scatter);
2222 XLA_TERNOP_PATTERN(Select);
2223 #undef XLA_TERNOP_PATTERN
2224 
2225 namespace detail {
2226 template <typename Matcher, typename FirstArg>
2227 inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg)
2228     -> decltype(m.WithOperand(operand_num, std::forward<FirstArg>(first_arg))) {
2229   return m.WithOperand(operand_num, std::forward<FirstArg>(first_arg));
2230 }
2231 
2232 template <typename Matcher, typename FirstArg, typename... Args>
2233 inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg,
2234                          Args&&... args)
2235     -> decltype(WithOperands(m.WithOperand(operand_num,
2236                                            std::forward<FirstArg>(first_arg)),
2237                              operand_num + 1, std::forward<Args>(args)...)) {
2238   return WithOperands(
2239       m.WithOperand(operand_num, std::forward<FirstArg>(first_arg)),
2240       operand_num + 1, std::forward<Args>(args)...);
2241 }
2242 }  // namespace detail
2243 
2244 #define XLA_VARIADIC_OP_PATTERN(NAME)                                         \
2245   inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) {         \
2246     return Op().WithOpcode(HloOpcode::k##NAME);                               \
2247   }                                                                           \
2248                                                                               \
2249   template <typename... Args>                                                 \
2250   inline auto NAME(Args&&... args)                                            \
2251       ->decltype(detail::WithOperands(Op().WithOpcode(HloOpcode::k##NAME)     \
2252                                           .WithNumOperands(sizeof...(Args)),  \
2253                                       0, std::forward<Args>(args)...)) {      \
2254     return detail::WithOperands(                                              \
2255         Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \
2256         /*operand_num=*/0, std::forward<Args>(args)...);                      \
2257   }                                                                           \
2258                                                                               \
2259   template <typename HloInstructionType, typename... Args>                    \
2260   inline auto NAME(HloInstructionType** matched_inst, Args&&... args)         \
2261       ->decltype(detail::WithOperands(Op(matched_inst)                        \
2262                                           .WithOpcode(HloOpcode::k##NAME)     \
2263                                           .WithNumOperands(sizeof...(Args)),  \
2264                                       0, std::forward<Args>(args)...)) {      \
2265     return detail::WithOperands(Op(matched_inst)                              \
2266                                     .WithOpcode(HloOpcode::k##NAME)           \
2267                                     .WithNumOperands(sizeof...(Args)),        \
2268                                 /*operand_num=*/0,                            \
2269                                 std::forward<Args>(args)...);                 \
2270   }
2271 
2272 // We could implement all ops as "variadic" ops, but it would make the
2273 // already-bad compile errors even worse.
2274 XLA_VARIADIC_OP_PATTERN(AfterAll);
2275 XLA_VARIADIC_OP_PATTERN(Concatenate);
2276 XLA_VARIADIC_OP_PATTERN(CustomCall);
2277 XLA_VARIADIC_OP_PATTERN(DynamicSlice)
2278 XLA_VARIADIC_OP_PATTERN(Map)
2279 XLA_VARIADIC_OP_PATTERN(Reduce);
2280 XLA_VARIADIC_OP_PATTERN(Sort);
2281 XLA_VARIADIC_OP_PATTERN(Tuple);
2282 
// Helpers for comparison instructions. For each comparison direction NAME this
// defines NAME(), NAME(lhs_pattern, rhs_pattern) and NAME(&inst, lhs_pattern,
// rhs_pattern): a kCompare instruction with ComparisonDirection::kNAME, where
// the two- and three-argument forms also require operands 0 and 1 to match
// `lhs_pattern` and `rhs_pattern` (the capturing form stores the match into
// `inst`).
#define XLA_COMPARE_PATTERN(NAME)                                              \
  inline auto NAME()->decltype(                                                \
      Op().WithOpcode(HloOpcode::kCompare)                                     \
          .WithComparisonDirection(ComparisonDirection::k##NAME)) {            \
    return Op()                                                                \
        .WithOpcode(HloOpcode::kCompare)                                       \
        .WithComparisonDirection(ComparisonDirection::k##NAME);                \
  }                                                                            \
                                                                               \
  template <typename InstT, typename LhsPattern, typename RhsPattern>          \
  inline auto NAME(InstT** capture, LhsPattern&& lhs_pattern,                  \
                   RhsPattern&& rhs_pattern)                                   \
      ->decltype(Op(capture)                                                   \
                     .WithOpcode(HloOpcode::kCompare)                          \
                     .WithOperand(0, std::forward<LhsPattern>(lhs_pattern))    \
                     .WithOperand(1, std::forward<RhsPattern>(rhs_pattern))    \
                     .WithComparisonDirection(ComparisonDirection::k##NAME)) { \
    return Op(capture)                                                         \
        .WithOpcode(HloOpcode::kCompare)                                       \
        .WithOperand(0, std::forward<LhsPattern>(lhs_pattern))                 \
        .WithOperand(1, std::forward<RhsPattern>(rhs_pattern))                 \
        .WithComparisonDirection(ComparisonDirection::k##NAME);                \
  }                                                                            \
                                                                               \
  template <typename LhsPattern, typename RhsPattern>                          \
  inline auto NAME(LhsPattern&& lhs_pattern, RhsPattern&& rhs_pattern)         \
      ->decltype(Op().WithOpcode(HloOpcode::kCompare)                          \
                     .WithOperand(0, std::forward<LhsPattern>(lhs_pattern))    \
                     .WithOperand(1, std::forward<RhsPattern>(rhs_pattern))    \
                     .WithComparisonDirection(ComparisonDirection::k##NAME)) { \
    return Op()                                                                \
        .WithOpcode(HloOpcode::kCompare)                                       \
        .WithOperand(0, std::forward<LhsPattern>(lhs_pattern))                 \
        .WithOperand(1, std::forward<RhsPattern>(rhs_pattern))                 \
        .WithComparisonDirection(ComparisonDirection::k##NAME);                \
  }
2319 
// XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) expands XLA_COMPARE_PATTERN(NAME) and
// additionally defines NAME##AnyOrder overloads. These match a kCompare
// instruction whose comparison direction is ComparisonDirection::k##NAME and
// whose two operands match `lhs` and `rhs` in either order. Use it only for
// symmetric directions, where swapping the operands does not change the
// result.
//
// The direction check is required: without it, e.g. EqAnyOrder(a, b) would
// also accept an Lt/Gt/Ne/... comparison of `a` and `b`.
#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME)                                  \
  XLA_COMPARE_PATTERN(NAME)                                                    \
                                                                               \
  template <typename HloInstructionType, typename Lhs, typename Rhs>           \
  inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs,     \
                             Rhs&& rhs)                                        \
      ->decltype(Op(matched_inst)                                              \
                     .WithOpcode(HloOpcode::kCompare)                          \
                     .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),       \
                                                 std::forward<Rhs>(rhs))       \
                     .WithComparisonDirection(ComparisonDirection::k##NAME)) { \
    return Op(matched_inst)                                                    \
        .WithOpcode(HloOpcode::kCompare)                                       \
        .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                    \
                                    std::forward<Rhs>(rhs))                    \
        .WithComparisonDirection(ComparisonDirection::k##NAME);                \
  }                                                                            \
  template <typename Lhs, typename Rhs>                                        \
  inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs)                             \
      ->decltype(NAME##AnyOrder<const HloInstruction>(                         \
          nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs))) {          \
    return NAME##AnyOrder<const HloInstruction>(                               \
        nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));              \
  }
2342 
// Eq and Ne are symmetric (a == b iff b == a, a != b iff b != a), so they also
// get NAME##AnyOrder overloads via XLA_COMMUTATIVE_COMPARE_PATTERN. The
// ordering comparisons are not symmetric and only get the ordered overloads.
XLA_COMMUTATIVE_COMPARE_PATTERN(Eq);
XLA_COMMUTATIVE_COMPARE_PATTERN(Ne);
XLA_COMPARE_PATTERN(Ge);
XLA_COMPARE_PATTERN(Gt);
XLA_COMPARE_PATTERN(Le);
XLA_COMPARE_PATTERN(Lt);
2349 
2350 // Helpers for matching non-constant instructions.
2351 inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
2352   return Op().IsNonConstant();
2353 }
2354 
2355 template <typename HloInstructionType>
2356 inline auto NonConstant(HloInstructionType** matched_inst)
2357     -> decltype(Op(matched_inst).IsNonConstant()) {
2358   return Op(matched_inst).IsNonConstant();
2359 }
2360 
2361 // Add overloads for GetTupleElement which take a int64 specifying which tuple
2362 // element is selected.
2363 template <typename Arg>
2364 inline auto GetTupleElement(Arg&& arg, int64 tuple_index)
2365     -> decltype(Op().WithOpcode(HloOpcode::kGetTupleElement)
2366                     .WithOperand(0, std::forward<Arg>(arg))
2367                     .WithTupleIndex(tuple_index)) {
2368   return Op()
2369       .WithOpcode(HloOpcode::kGetTupleElement)
2370       .WithOperand(0, std::forward<Arg>(arg))
2371       .WithTupleIndex(tuple_index);
2372 }
2373 
2374 template <typename HloInstructionType, typename Arg>
2375 inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
2376                             int64 tuple_index)
2377     -> decltype(Op(matched_inst)
2378                     .WithOpcode(HloOpcode::kGetTupleElement)
2379                     .WithOperand(0, std::forward<Arg>(arg))
2380                     .WithTupleIndex(tuple_index)) {
2381   return Op(matched_inst)
2382       .WithOpcode(HloOpcode::kGetTupleElement)
2383       .WithOperand(0, std::forward<Arg>(arg))
2384       .WithTupleIndex(tuple_index);
2385 }
2386 
2387 // Add overloads for Parameter which take an int64 specifying the parameter
2388 // number.
2389 inline auto Parameter(int64 parameter_num) -> decltype(
2390     Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num)) {
2391   return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num);
2392 }
2393 template <typename HloInstructionType>
2394 inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num)
2395     -> decltype(Op(matched_inst)
2396                     .WithOpcode(HloOpcode::kParameter)
2397                     .WithParameterNum(parameter_num)) {
2398   return Op(matched_inst)
2399       .WithOpcode(HloOpcode::kParameter)
2400       .WithParameterNum(parameter_num);
2401 }
2402 
2403 inline auto ConstantScalar() -> decltype(Op().IsConstantScalar()) {
2404   return Op().IsConstantScalar();
2405 }
2406 
2407 template <typename HloInstructionType>
2408 inline auto ConstantScalar(HloInstructionType** matched_inst)
2409     -> decltype(Op(matched_inst).IsConstantScalar()) {
2410   return Op(matched_inst).IsConstantScalar();
2411 }
2412 
2413 template <typename ScalarTy>
2414 inline auto ConstantScalar(ScalarTy val)
2415     -> decltype(Op().IsConstantScalar(val)) {
2416   return Op().IsConstantScalar(val);
2417 }
2418 
2419 template <typename HloInstructionType, typename ScalarTy>
2420 inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val)
2421     -> decltype(Op(matched_inst).IsConstantScalar(val)) {
2422   return Op(matched_inst).IsConstantScalar(val);
2423 }
2424 
2425 inline auto ConstantEffectiveScalar() -> decltype(Op().IsConstantScalar()) {
2426   return Op().IsConstantEffectiveScalar();
2427 }
2428 
2429 template <typename HloInstructionType>
2430 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst)
2431     -> decltype(Op(matched_inst).IsConstantScalar()) {
2432   return Op(matched_inst).IsConstantEffectiveScalar();
2433 }
2434 
2435 template <typename ScalarTy>
2436 inline auto ConstantEffectiveScalar(ScalarTy val)
2437     -> decltype(Op().IsConstantEffectiveScalar(val)) {
2438   return Op().IsConstantEffectiveScalar(val);
2439 }
2440 
2441 template <typename HloInstructionType, typename ScalarTy>
2442 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst,
2443                                     ScalarTy val)
2444     -> decltype(Op(matched_inst).IsConstantEffectiveScalar(val)) {
2445   return Op(matched_inst).IsConstantEffectiveScalar(val);
2446 }
2447 
2448 }  // namespace match
2449 
2450 }  // namespace xla
2451 
// EXPLAIN is presumably a helper macro used earlier in this header. Remove
// this header's definition, then restore whatever definition of EXPLAIN was
// saved by the most recent #pragma push_macro("EXPLAIN") (presumably issued
// near the top of this file; not visible in this chunk), so includers see
// their own EXPLAIN, if any, unchanged.
#undef EXPLAIN
#pragma pop_macro("EXPLAIN")
2454 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
2455