1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
18
19 #include "absl/strings/str_replace.h"
20 #include "absl/strings/string_view.h"
21 #include "absl/utility/utility.h"
22 #include "tensorflow/compiler/xla/layout_util.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29
30 namespace xla {
31
32 // A pattern matcher for HloInstructions, Shapes, and Layouts.
33 //
34 // The Match function's first argument must be HloInstruction*, Shape*, or
35 // Layout*. The second argument is a pattern that will be matched against the
36 // first argument, as described below.
37 //
38 // Patterns are constructed using the match::Op, match::Shape, or match::Layout
39 // functions. By default, the returned patterns will match any HloInstruction,
40 // Shape, or Layout, respectively. However the match can be made more specific
41 // by using the pattern's modifier methods, for example:
42 //
43 // match::Op().WithOpcode(HloOpcode::kAdd).WithOperand(
44 // 0, match::Op().WithOpcode(HloOpcode::kConstant))
45 //
46 // This pattern will match Add instructions whose first operand is a constant.
47 //
48 // Each pattern type has the following modifiers, which are described where
49 // nontrivial.
50 //
51 // Op():
52 // - Is: is the given HloInstruction* (i.e. pointer equality)
53 // - WithName
54 // - WithOpcode
55 // - WithoutOpcode: anything other than the given opcode
56 // - WithShape: instr's shape matches the given pattern
57 // - WithShapeEqualTo: instr's shape is equal to the given Shape
58 // - WithShapeCompatibleTo: instr's shape is compatible with the given Shape
59 // - WithNumOperands
60 // - WithOperand: operand at the given index matches the given pattern
61 // - IsConstant
62 // - IsNonConstant
// - IsConstantScalar/IsConstantEffectiveScalar: Optionally accepts a value,
//   e.g. IsConstantScalar() or IsConstantScalar(42).
65 // - WithFusionKind
66 // - WithTupleIndex: get-tuple-element operations with the given tuple index
67 // - WithOneUse: Instruction is used as an operand exactly once.
68 // - WithOneUser: Instruction is used by exactly one other instruction, but
69 // is possibly used more than once as an operand (e.g. multiply(x,x)).
70 // - WithComparisonDirection: instr has the given direction
71 //
72 // Shape():
73 // - EqualTo
74 // - CompatibleTo
75 // - IsScalar/IsEffectiveScalar/IsArray/IsTuple
76 // - IsDenseArray/IsSparseArray
// - WithLayout: shape's layout matches the given pattern (e.g.
//   Layout().WithDenseFormat())
79 // - WithLayoutEqualTo: shape's layout equals the argument (i.e. another
80 // Layout, but not the result of Layout().foo())
81 // - WithSubshape: shape is a tuple whose subshape matches the given pattern
82 // (e.g. Shape().IsScalar()).
83 // - WithSubshapeEqualTo: shape is a tuple with a subshape equal to the arg
84 // (i.e. another Shape, but not the result of Shape().foo())
85 // - WithElementType: shape is an array/scalar with the given elem type
86 // - WithRank: shape is an array/scalar with the given rank
87 //
88 // Layout():
89 // - EqualTo
90 // - WithDenseFormat/WithSparseFormat
91 //
92 // Op(), Shape(), and Layout() may be passed an argument of type
93 // HloInstruction**, Shape**, or Layout**, respectively, or const versions of
94 // these pointers. If the pattern is matched, the address of the matched value
95 // will be "captured" and stored at this location.
96 //
97 // For example:
98 // HloInstruction* foo = ...;
99 // HloInstruction* matched_operand;
100 // CHECK(Match(foo,
101 // match::Op().WithOperand(0, match::Op(&matched_operand))));
102 //
103 // Helpers are provided for most HLO instructions. These helpers can be called
104 // with no arguments, in which case they will match any instruction matching the
105 // opcode. They may also be called with matches for the operands and with an
106 // optional capture. (The capture must be the first argument.) Some examples of
107 // these helpers and their equivalents are provided below.
108
109 // Example nullary instruction:
110 // Parameter() == Op().WithOpcode(HloOpcode::kParameter)
111 // Parameter(&a) == Op(&a).WithOpcode(HloOpcode::kParameter)
112 //
113 // Example unary instruction:
114 // Abs() == Op().WithOpcode(HloOpcode::kAbs)
//   Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs)
//                      .WithOperand(0, Op(&a))
117 // Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs)
118 // .WithOperand(0, Op(&b))
119 //
120 // Commutative binary instructions have a special form that accepts either order
121 // of args, e.g.:
122 //
123 // AddAnyOrder(Parameter(1), Abs()) ==
124 // Op().WithOpcode(HloOpcode::kAdd)
125 // .WithBinaryOperandsAnyOrder(Op().WithParameterNum(1), Abs());
126 //
127 // MultiplyAnyOrder(&a, Parameter(), Abs()) // Captures the mul in `a`.
128 //
129 // The following additional helpers are provided. In all cases, `&a` is
130 // optional.
131 //
132 // ConstantScalar(&a) == Op(&a).IsConstantScalar();
133 // ConstantScalar(&a, v) == Op(&a).IsConstantScalar(v);
134 // ConstantEffectiveScalar(&a) == Op(&a).IsConstantEffectiveScalar();
//   ConstantEffectiveScalar(&a, v) == Op(&a).IsConstantEffectiveScalar(v)
136 // NonConstant(&a) == Op(&a).IsNonConstant()
137 // GetTupleElement(&a, b, index) == Op(&a).WithTupleIndex(index)
138 // .WithOperand(0, b);
139 // Parameter(&a, n) == Op(&a).WithParameterNum(n);
140
// Options controlling a single call to a pattern's Match(). Initialized as an
// aggregate, so the member order below is part of the interface.
struct MatchOption {
  // If true, actually capture matched item into the user pointer.
  bool capture;

  // If non-null, an explanation of why the match failed is streamed here.
  std::ostream* explain_os;
};
148
149 template <typename Value, typename Pattern>
150 bool Match(Value* value, const Pattern& pattern,
151 MatchOption option = {/*.capture=*/true, /*.explain_os=*/nullptr}) {
152 if (option.capture) {
153 auto new_option = option;
154 new_option.capture = false;
155 if (!pattern.Match(value, new_option)) {
156 return false;
157 }
158 }
159 return pattern.Match(value, option);
160 }
161
162 namespace match {
163
164 namespace detail {
165
// Macro for streaming to option.explain_os if it's not null.
//
//   EXPLAIN << "value of foo(): " << foo()
//
// Requires a MatchOption named `option` to be in scope. The streamed operands
// are evaluated only when option.explain_os is non-null.
//
// The expansion is `if (!option.explain_os) {} else *option.explain_os` rather
// than `if (option.explain_os) *option.explain_os`. Because the macro's own
// `if` already carries an `else`, a caller's unbraced
//   if (cond) EXPLAIN << "a"; else EXPLAIN << "b";
// binds its `else` to `if (cond)` as written instead of silently attaching it
// to the macro's hidden `if` (the dangling-else trap).
#pragma push_macro("EXPLAIN")
#define EXPLAIN if (!option.explain_os) {} else *option.explain_os

// kIndentInc is the additional number of spaces that we indent by when we
// increase the indent "by one".
enum {
  kIndentInc = 2,
};
179
180 // Writes a newline and then `indent` spaces.
181 //
182 // We follow an unintuitive convention in this file's pretty-printers: Indents
183 // are performed by the caller, not the callee. For example, if you want to
184 // print
185 //
186 // foo:
187 // - bar
188 //
189 // you'd do:
190 //
191 // Foo::DescribeTo(std::ostream* os, int64 indent) {
192 // *os << "foo:";
193 // Indent(os, indent) // Create a newline at the *current* indent level.
194 // *os << " - ";
195 // bar.DescribeTo(os, indent + 3); // + 3 because strlen(" * ") == 3.
196 // }
197 //
198 // Bar::DescribeTo(std::ostream* os, int64 indent) { *os << "bar"; }
199 //
200 // Notice that Bar::DescribeTo() does not call Indent; the indenting is
201 // performed by Foo. This convention allows the caller to decide whether a
202 // matcher is preceded by a newline, which is important e.g. for the AllOf
203 // matcher.
204 //
205 // (Incidentally, indenting in Match's explanations is handled differently.
206 // Indents are a common case in DescribeTo [we're printing a whole tree], but
207 // they're a special case in Match [we're printing only a path through the tree
208 // that encounters a failing node]. Indents in Match only appear when we
209 // encounter a failing disjunction, so we just handle them as a special case
210 // there.)
Indent(std::ostream * os,int64 indent)211 inline void Indent(std::ostream* os, int64 indent) {
212 *os << "\n";
213 for (int64 i = 0; i < indent; ++i) {
214 *os << " ";
215 }
216 }
217
// IsTrivialMatcher<T>::value is true iff T declares a member kIsTrivialMatcher
// whose value is true. Detection is by SFINAE: if `T::kIsTrivialMatcher` is
// missing or false, the partial specialization drops out and the primary
// template (false) is used.
//
// Trivial matchers get special treatment. For example, when printing a
// conjunction of matchers, we don't print "and" after a trivial matcher. This
// yields e.g.
//   "a shape compatible with f32[1,2]"
// rather than
//   "a shape AND compatible with f32[1,2]"
template <typename T, typename Dummy = void>
struct IsTrivialMatcher : std::false_type {};

template <typename T>
struct IsTrivialMatcher<T, typename std::enable_if<T::kIsTrivialMatcher>::type>
    : std::true_type {};
236
237 template <typename Item, typename... Patterns>
238 class AllOfPattern {
239 public:
240 explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
241
242 bool Match(const Item* item, MatchOption option) const {
243 bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
244 // This invariant is guaranteed by the top-level Match and AnyOf.
245 DCHECK(matched || !option.capture);
246 return matched;
247 }
248
249 bool Match(Item* item, MatchOption option) const {
250 bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
251 // This invariant is guaranteed by the top-level Match and AnyOf.
252 DCHECK(matched || !option.capture);
253 return matched;
254 }
255
256 void DescribeTo(std::ostream* os, int64 indent = 0) const {
257 DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
258 }
259
260 // Accessor for patterns_. Please don't use this outside of this file.
261 const std::tuple<Patterns...>& patterns() const { return patterns_; }
262
263 private:
264 template <typename ItemType, size_t index>
265 bool MatchImpl(ItemType* item, MatchOption option,
266 std::integral_constant<size_t, index>) const {
267 // We don't need to do any EXPLAINing here; it's all correctly handled by
268 // our sub-matchers (if any fail).
269 return std::get<index>(patterns_).Match(item, option) &&
270 MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
271 }
272
273 template <typename ItemType>
274 bool MatchImpl(ItemType* item, MatchOption option,
275 std::integral_constant<size_t, sizeof...(Patterns)>) const {
276 return true;
277 }
278
279 // Pretty-printing a conjunction has some special cases to make it easy to
280 // read in the simple (common) case.
281 //
282 // If sizeof...(Patterns) == 1, prints as e.g.
283 //
284 // a shape
285 //
286 // If sizeof...(Patterns) == 2 and patterns_[0] is a trivial matcher (e.g. "a
287 // shape") prints as
288 //
289 // a shape compatible with f32[1,2]
290 //
291 // If sizeof...(Patterns) > 2 and patterns_[0] is a trivial matcher, prints as
292 //
293 // a shape:
294 // * compatible with f32[1,2] AND
295 // * that represents a scalar
296 //
297 // Otherwise prints as:
298 //
299 // all of:
300 // * foo AND
301 // * bar
302 //
303 template <size_t index>
304 void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
305 int64 indent) const {
306 constexpr bool first_is_trivial =
307 IsTrivialMatcher<typename std::remove_reference<decltype(
308 std::get<0>(patterns_))>::type>::value;
309 constexpr bool is_last = index == sizeof...(Patterns) - 1;
310 const auto& submatcher = std::get<index>(patterns_);
311
312 auto print_bulleted_item = [&] {
313 *os << " * ";
314 submatcher.DescribeTo(os, indent + 3);
315 if (!is_last) {
316 *os << " AND";
317 Indent(os, indent);
318 }
319 };
320
321 if (index == 0) {
322 if (first_is_trivial || is_last) {
323 submatcher.DescribeTo(os, indent + kIndentInc);
324 if (sizeof...(Patterns) > 2) {
325 *os << ":";
326 Indent(os, indent);
327 }
328 } else {
329 *os << "all of:";
330 Indent(os, indent);
331 print_bulleted_item();
332 }
333 } else if (first_is_trivial && index == 1 && sizeof...(Patterns) == 2) {
334 *os << " ";
335 submatcher.DescribeTo(os, indent);
336 } else {
337 print_bulleted_item();
338 }
339 DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
340 }
341
342 void DescribeToImpl(std::ostream* os,
343 std::integral_constant<size_t, sizeof...(Patterns)>,
344 int64 indent) const {}
345
346 std::tuple<Patterns...> patterns_;
347 };
348
349 } // namespace detail
350
351 // Returns a pattern that represents the conjunction of all input patterns. All
352 // patterns need to match in order to have the AllOf pattern match.
353 template <typename Item, typename... Patterns>
354 detail::AllOfPattern<typename std::remove_const<Item>::type, Patterns...> AllOf(
355 const Patterns&... patterns) {
356 return detail::AllOfPattern<typename std::remove_const<Item>::type,
357 Patterns...>(patterns...);
358 }
359
360 // AllOf<AllOf<A, B...>, X, Y, ...> => AllOf<A, B, ..., X, Y, ...>.
361 //
362 // This transformation is necessary for good pretty-printing.
363 template <typename Item, typename... InnerPs, typename... OuterPs>
364 detail::AllOfPattern<typename std::remove_const<Item>::type, InnerPs...,
365 OuterPs...>
366 AllOf(const detail::AllOfPattern<Item, InnerPs...>& inner_p,
367 const OuterPs&... outer_ps) {
368 // Invoke constructor of AllOfPattern<Item, InnerPs..., OuterPs...>.
369 auto make_all_of = [](const InnerPs&... inner_ps,
370 const OuterPs&... outer_ps) {
371 return detail::AllOfPattern<typename std::remove_const<Item>::type,
372 InnerPs..., OuterPs...>(inner_ps...,
373 outer_ps...);
374 };
375 return absl::apply(make_all_of, std::tuple_cat(inner_p.patterns(),
376 std::make_tuple(outer_ps...)));
377 }
378
379 namespace detail {
380
381 template <typename LayoutType, typename Impl>
382 class LayoutPattern;
383
384 // The base LayoutPattern implementation. Matches only if the layout is not
385 // nullptr.
386 class LayoutPatternBaseImpl {
387 public:
388 bool Match(const ::xla::Layout* layout, MatchOption option) const {
389 if (layout == nullptr) {
390 EXPLAIN << "Layout is null";
391 return false;
392 }
393 return true;
394 }
395
396 void DescribeTo(std::ostream* os, int64 indent = 0) const {
397 *os << "a layout";
398 }
399
400 static constexpr bool kIsTrivialMatcher = true;
401 };
402
403 // A LayoutPattern implementation that matches only if the layout equals a
404 // Layout proto.
405 class LayoutPatternEqualImpl {
406 public:
407 explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout)
408 : layout_(layout) {}
409
410 bool Match(const ::xla::Layout* layout, MatchOption option) const {
411 if (!LayoutUtil::Equal(*layout_, *layout)) {
412 EXPLAIN << "Layout " << LayoutUtil::HumanString(*layout)
413 << " is not equal to expected "
414 << LayoutUtil::HumanString(*layout_);
415 return false;
416 }
417 return true;
418 }
419
420 void DescribeTo(std::ostream* os, int64 indent = 0) const {
421 *os << "equal to " << LayoutUtil::HumanString(*layout_);
422 }
423
424 private:
425 const ::xla::Layout* layout_;
426 };
427
428 // A LayoutPattern implementation that matches only if the layout has a given
429 // format.
430 class LayoutPatternFormatImpl {
431 public:
432 explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {}
433
434 bool Match(const ::xla::Layout* layout, MatchOption option) const {
435 if (layout->format() != format_) {
436 EXPLAIN << "Layout has format " << Format_Name(layout->format())
437 << " but expected " << Format_Name(format_);
438 return false;
439 }
440 return true;
441 }
442
443 void DescribeTo(std::ostream* os, int64 indent = 0) const {
444 *os << "with format " << Format_Name(format_);
445 }
446
447 private:
448 Format format_;
449 };
450
451 // A pattern that matches Layouts.
452 template <typename LayoutType, typename Impl>
453 class LayoutPattern {
454 private:
455 template <typename NewImpl>
456 auto AppendImpl(NewImpl new_impl) const
457 -> LayoutPattern<LayoutType,
458 decltype(AllOf<Layout>(std::declval<Impl>(),
459 std::move(new_impl)))> {
460 auto new_allof = AllOf<Layout>(impl_, std::move(new_impl));
461 return LayoutPattern<LayoutType, decltype(new_allof)>(std::move(new_allof),
462 matched_layout_);
463 }
464
465 public:
466 explicit constexpr LayoutPattern(const Impl& impl,
467 LayoutType** matched_layout)
468 : impl_(impl), matched_layout_(matched_layout) {}
469
470 // Returns true and captures the layout iff it matches the pattern.
471 bool Match(const ::xla::Layout* layout, MatchOption option) const {
472 if (impl_.Match(layout, option)) {
473 if (option.capture && matched_layout_) {
474 *matched_layout_ = layout;
475 }
476 return true;
477 }
478 return false;
479 }
480
481 // Returns true and captures the layout iff it matches the pattern.
482 bool Match(::xla::Layout* layout, MatchOption option) const {
483 if (impl_.Match(layout, option)) {
484 if (option.capture && matched_layout_) {
485 *matched_layout_ = layout;
486 }
487 return true;
488 }
489 return false;
490 }
491
492 void DescribeTo(std::ostream* os, int64 indent = 0) const {
493 impl_.DescribeTo(os, indent);
494 }
495
496 // Modifies the pattern to match only if the layout equals the given proto.
497 // The layout must outlive the returned pattern.
498 constexpr auto EqualTo(const ::xla::Layout* layout) const
499 -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) {
500 return AppendImpl(LayoutPatternEqualImpl(layout));
501 }
502
503 // Modifies the pattern to match only if the layout has a dense format.
504 constexpr auto WithDenseFormat() const
505 -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) {
506 return AppendImpl(LayoutPatternFormatImpl(DENSE));
507 }
508
509 // Modifies the pattern to match only if the layout has a sparse format.
510 constexpr auto WithSparseFormat() const
511 -> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) {
512 return AppendImpl(LayoutPatternFormatImpl(SPARSE));
513 }
514
515 private:
516 Impl impl_;
517 LayoutType** matched_layout_;
518 };
519
520 template <typename Item, typename... Patterns>
521 class AnyOfPattern {
522 public:
523 explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
524
525 bool Match(const Item* item, MatchOption option) const {
526 return MatchImpl(item, option);
527 }
528
529 bool Match(Item* item, MatchOption option) const {
530 return MatchImpl(item, option);
531 }
532
533 void DescribeTo(std::ostream* os, int64 indent = 0) const {
534 *os << "any of:";
535 Indent(os, indent);
536 DescribeToImpl(os, std::integral_constant<size_t, 0>(), indent);
537 }
538
539 private:
540 template <typename ItemType>
541 bool MatchImpl(ItemType* item, MatchOption option) const {
542 // If we're generating an explanation, buffer it until we know we failed.
543 absl::optional<std::stringstream> explanation;
544 MatchOption new_option = option;
545 if (option.explain_os) {
546 new_option.explain_os = &explanation.emplace();
547 }
548 bool rv = MatchRecursiveImpl(item, new_option,
549 std::integral_constant<size_t, 0>());
550 if (!rv && option.explain_os) {
551 EXPLAIN << "None of the following matchers succeeded:";
552 EXPLAIN << explanation->str();
553 }
554 return rv;
555 }
556
557 template <typename ItemType, size_t index>
558 bool MatchRecursiveImpl(ItemType* item, MatchOption option,
559 std::integral_constant<size_t, index>) const {
560 auto new_option = option;
561 new_option.capture = false;
562
563 absl::optional<std::stringstream> explanation;
564 if (option.explain_os) {
565 new_option.explain_os = &explanation.emplace();
566 }
567
568 // Try to match the sub-pattern without capturing behavior.
569 if (std::get<index>(patterns_).Match(item, new_option)) {
570 // Capture the branch.
571 if (option.capture) {
572 // TODO(timshen): Currently the behavior can be exponential. Optimize it
573 // with memoization or recording the matched sub-pattern index, if it
574 // takes too long to run.
575 //
576 // Specifically, the "memoization" approach is to create an empty
577 // container with the key (pattern, instruction), and value as whether
578 // matched or not.
579 //
580 // Alternatively, we may run the pattern matching with captures off, but
581 // instead record a "trace" somewhere, indicating how exactly the
582 // pattern matches the input. For example, the trace information for
583 // AnyOf will be a runtime number indicate which sub-pattern is matched.
584 // Then we run another pass to do captures only with the help of the
585 // trace.
586 bool matched = std::get<index>(patterns_).Match(item, option);
587 DCHECK(matched);
588 }
589 return true;
590 }
591 if (option.explain_os) {
592 EXPLAIN << "\nMatcher #" << index + 1;
593 EXPLAIN << "\n - ";
594 std::get<index>(patterns_).DescribeTo(option.explain_os, /*indent=*/3);
595 EXPLAIN << "\nfailed with";
596 EXPLAIN << "\n - ";
597 EXPLAIN << absl::StrReplaceAll(explanation->str(), {{"\n", "\n "}});
598 }
599 return MatchRecursiveImpl(item, option,
600 std::integral_constant<size_t, index + 1>());
601 }
602
603 template <typename ItemType>
604 bool MatchRecursiveImpl(
605 ItemType* item, MatchOption option,
606 std::integral_constant<size_t, sizeof...(Patterns)>) const {
607 return false;
608 }
609
610 template <size_t index>
611 void DescribeToImpl(std::ostream* os, std::integral_constant<size_t, index>,
612 int64 indent) const {
613 *os << " - ";
614 std::get<index>(patterns_).DescribeTo(os, indent + 3);
615 if (index != sizeof...(Patterns) - 1) {
616 *os << " OR";
617 Indent(os, indent);
618 }
619 DescribeToImpl(os, std::integral_constant<size_t, index + 1>(), indent);
620 }
621
622 void DescribeToImpl(std::ostream* os,
623 std::integral_constant<size_t, sizeof...(Patterns)>,
624 int64 indent) const {}
625
626 std::tuple<Patterns...> patterns_;
627 };
628
629 } // namespace detail
630
631 // Returns a pattern that represents the logical disjunction of the input
632 // patterns. The returned pattern matches from left to right, and stops on the
633 // first match.
634 template <typename Item, typename... Patterns>
635 detail::AnyOfPattern<typename std::remove_const<Item>::type, Patterns...> AnyOf(
636 const Patterns&... patterns) {
637 return detail::AnyOfPattern<typename std::remove_const<Item>::type,
638 Patterns...>(patterns...);
639 }
640
641 // Creates a layout pattern that will capture the matched layout in the
642 // argument.
643 inline constexpr detail::LayoutPattern<const ::xla::Layout,
644 detail::LayoutPatternBaseImpl>
645 Layout(const ::xla::Layout** matched_layout = nullptr) {
646 return detail::LayoutPattern<const ::xla::Layout,
647 detail::LayoutPatternBaseImpl>(
648 detail::LayoutPatternBaseImpl(), matched_layout);
649 }
650
651 // Creates a layout pattern that will capture the matched layout in the
652 // argument.
653 inline constexpr detail::LayoutPattern<::xla::Layout,
654 detail::LayoutPatternBaseImpl>
655 Layout(::xla::Layout** matched_layout) {
656 return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>(
657 detail::LayoutPatternBaseImpl(), matched_layout);
658 }
659
660 namespace detail {
661
662 template <typename ShapeType, typename Impl>
663 class ShapePattern;
664
665 // The base ShapePattern implementation. Matches only if the shape is not
666 // nullptr.
667 class ShapePatternBaseImpl {
668 public:
669 bool Match(const ::xla::Shape* shape, MatchOption option) const {
670 if (shape == nullptr) {
671 EXPLAIN << "Shape is null";
672 }
673 return shape != nullptr;
674 }
675
676 void DescribeTo(std::ostream* os, int64 indent = 0) const {
677 *os << "a shape";
678 }
679
680 static constexpr bool kIsTrivialMatcher = true;
681 };
682
683 // A ShapePattern implementation that matches only if the shape equals a Shape
684 // proto.
685 class ShapePatternEqualImpl {
686 public:
687 explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape)
688 : shape_(shape) {}
689
690 bool Match(const ::xla::Shape* shape, MatchOption option) const {
691 if (!ShapeUtil::Equal(*shape_, *shape)) {
692 EXPLAIN << "Shape not equal to "
693 << ShapeUtil::HumanStringWithLayout(*shape_);
694 return false;
695 }
696 return true;
697 }
698
699 void DescribeTo(std::ostream* os, int64 indent = 0) const {
700 *os << "equal to " << ShapeUtil::HumanStringWithLayout(*shape_);
701 }
702
703 private:
704 const ::xla::Shape* shape_;
705 };
706
707 // A ShapePattern implementation that matches only if the shape is compatible to
708 // a Shape proto.
709 class ShapePatternCompatibleImpl {
710 public:
711 explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape)
712 : shape_(shape) {}
713
714 bool Match(const ::xla::Shape* shape, MatchOption option) const {
715 if (!ShapeUtil::Compatible(*shape_, *shape)) {
716 EXPLAIN << "Shape not compatible with "
717 << ShapeUtil::HumanString(*shape_);
718 return false;
719 }
720 return true;
721 }
722
723 void DescribeTo(std::ostream* os, int64 indent = 0) const {
724 *os << "compatible with " << ShapeUtil::HumanString(*shape_);
725 }
726
727 private:
728 const ::xla::Shape* shape_;
729 };
730
731 // A ShapePattern implementation that matches only if the shape has a given
732 // element type.
733 class ShapePatternElementTypeImpl {
734 public:
735 explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type)
736 : element_type_(element_type) {}
737
738 bool Match(const ::xla::Shape* shape, MatchOption option) const {
739 if (shape->element_type() != element_type_) {
740 EXPLAIN << "Shape does not have element type "
741 << PrimitiveType_Name(element_type_);
742 return false;
743 }
744 return true;
745 }
746
747 void DescribeTo(std::ostream* os, int64 indent = 0) const {
748 *os << "with element type " << PrimitiveType_Name(element_type_);
749 }
750
751 private:
752 PrimitiveType element_type_;
753 };
754
755 // A ShapePattern implementation that matches only if the shape is scalar.
756 class ShapePatternIsScalarImpl {
757 public:
758 explicit constexpr ShapePatternIsScalarImpl() {}
759
760 bool Match(const ::xla::Shape* shape, MatchOption option) const {
761 if (!ShapeUtil::IsScalar(*shape)) {
762 EXPLAIN << "Shape is not a scalar";
763 return false;
764 }
765 return true;
766 }
767
768 void DescribeTo(std::ostream* os, int64 indent = 0) const {
769 *os << "that represents a scalar";
770 }
771 };
772
773 // A ShapePattern implementation that matches only if the shape is an array
774 class ShapePatternIsArrayImpl {
775 public:
776 explicit constexpr ShapePatternIsArrayImpl() {}
777
778 bool Match(const ::xla::Shape* shape, MatchOption option) const {
779 if (!shape->IsArray()) {
780 EXPLAIN << "Shape is not an array";
781 return false;
782 }
783 return true;
784 }
785
786 void DescribeTo(std::ostream* os, int64 indent = 0) const {
787 *os << "that represents an array";
788 }
789 };
790
791 // A ShapePattern implementation that matches only if the shape is a tuple.
792 class ShapePatternIsTupleImpl {
793 public:
794 explicit constexpr ShapePatternIsTupleImpl() {}
795
796 bool Match(const ::xla::Shape* shape, MatchOption option) const {
797 if (!shape->IsTuple()) {
798 EXPLAIN << "Shape is not a tuple";
799 return false;
800 }
801 return true;
802 }
803
804 void DescribeTo(std::ostream* os, int64 indent = 0) const {
805 *os << "that represents a tuple";
806 }
807 };
808
809 // A ShapePattern implementation that matches only if the shape is an effective
810 // scalar.
811 class ShapePatternEffectiveScalarImpl {
812 public:
813 explicit constexpr ShapePatternEffectiveScalarImpl() {}
814
815 bool Match(const ::xla::Shape* shape, MatchOption option) const {
816 if (!ShapeUtil::IsEffectiveScalar(*shape)) {
817 EXPLAIN << "Shape is not an effective scalar";
818 return false;
819 }
820 return true;
821 }
822
823 void DescribeTo(std::ostream* os, int64 indent = 0) const {
824 *os << "that is an effective scalar";
825 }
826 };
827
828 // A ShapePattern implementation that matches only if the shape has a given
829 // rank.
830 class ShapePatternRankImpl {
831 public:
832 explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {}
833
834 bool Match(const ::xla::Shape* shape, MatchOption option) const {
835 if (shape->rank() != rank_) {
836 if (rank_ == 0) {
837 EXPLAIN << "Shape is not a scalar";
838 } else {
839 EXPLAIN << "Shape does not have rank " << rank_;
840 }
841 return false;
842 }
843 return true;
844 }
845
846 void DescribeTo(std::ostream* os, int64 indent = 0) const {
847 if (rank_ == 0) {
848 *os << "that is a scalar";
849 } else {
850 *os << "that has " << rank_ << " dimension" << (rank_ != 1 ? "s" : "");
851 }
852 }
853
854 private:
855 int64 rank_;
856 };
857
858 // A ShapePattern implementation that matches only if the shape has a layout
859 // that matches a given pattern.
860 template <typename LayoutType, typename LayoutImpl>
861 class ShapePatternLayoutImpl {
862 public:
863 explicit constexpr ShapePatternLayoutImpl(
864 const LayoutPattern<LayoutType, LayoutImpl>& layout)
865 : layout_(layout) {}
866
867 bool Match(const ::xla::Shape* shape, MatchOption option) const {
868 return LayoutUtil::HasLayout(*shape) &&
869 layout_.Match(&shape->layout(), option);
870 }
871
872 bool Match(Shape* shape, MatchOption option) const {
873 if (!LayoutUtil::HasLayout(*shape)) {
874 EXPLAIN << "Shape does not have a layout";
875 return false;
876 }
877 if (!layout_.Match(shape->mutable_layout(), option)) {
878 EXPLAIN << "\nin layout";
879 return false;
880 }
881 return true;
882 }
883
884 void DescribeTo(std::ostream* os, int64 indent = 0) const {
885 *os << "with";
886 Indent(os, indent + kIndentInc);
887 layout_.DescribeTo(os, indent + kIndentInc);
888 }
889
890 private:
891 LayoutPattern<LayoutType, LayoutImpl> layout_;
892 };
893
894 // A ShapePattern implementation that matches only if the shape has a subshape
895 // that matches a given pattern.
896 template <typename SubshapeType, typename SubshapeImpl>
897 class ShapePatternSubshapeImpl {
898 public:
899 explicit ShapePatternSubshapeImpl(
900 ShapeIndexView index,
901 const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
902 : index_(index), subshape_(subshape) {}
903
904 bool Match(const ::xla::Shape* shape, MatchOption option) const {
905 return MatchImpl(shape, option);
906 }
907
908 bool Match(::xla::Shape* shape, MatchOption option) const {
909 return MatchImpl(shape, option);
910 }
911
912 void DescribeTo(std::ostream* os, int64 indent = 0) const {
913 *os << "with subshape at index " << index_.ToString() << " which is";
914 Indent(os, indent + kIndentInc);
915 subshape_.DescribeTo(os, indent + kIndentInc);
916 }
917
918 private:
919 Shape* GetSubshape(Shape* shape) const {
920 return ShapeUtil::GetMutableSubshape(shape, index_);
921 }
922 const Shape* GetSubshape(const Shape* shape) const {
923 return &ShapeUtil::GetSubshape(*shape, index_);
924 }
925
926 template <typename ShapeType>
927 bool MatchImpl(ShapeType* shape, MatchOption option) const {
928 if (!ShapeUtil::IndexIsValid(*shape, index_)) {
929 EXPLAIN << "No subshape at " << index_.ToString();
930 return false;
931 }
932 if (!subshape_.Match(GetSubshape(shape), option)) {
933 EXPLAIN << "\nin subshape at " << index_.ToString();
934 return false;
935 }
936 return true;
937 }
938
939 ShapeIndexView index_;
940 ShapePattern<SubshapeType, SubshapeImpl> subshape_;
941 };
942
943 // A pattern that matches Shapes.
944 template <typename ShapeType, typename Impl>
945 class ShapePattern {
946 private:
947 template <typename NewImpl>
948 auto AppendImpl(NewImpl new_impl) const
949 -> ShapePattern<ShapeType, decltype(AllOf<Shape>(std::declval<Impl>(),
950 std::move(new_impl)))> {
951 auto new_all_of = AllOf<Shape>(impl_, std::move(new_impl));
952 return ShapePattern<ShapeType, decltype(new_all_of)>(std::move(new_all_of),
953 matched_shape_);
954 }
955
956 public:
957 explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
958 : impl_(impl), matched_shape_(matched_shape) {}
959
960 // Returns true and captures the shape iff it matches the pattern.
961 bool Match(const ::xla::Shape* shape, MatchOption option) const {
962 if (impl_.Match(shape, option)) {
963 if (option.capture && matched_shape_) {
964 *matched_shape_ = shape;
965 }
966 return true;
967 }
968 if (shape) {
969 EXPLAIN << "\nin "
970 << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
971 : ShapeUtil::HumanString(*shape));
972 }
973 return false;
974 }
975
976 // Returns true and captures the shape iff it matches the pattern.
977 bool Match(::xla::Shape* shape, MatchOption option) const {
978 if (impl_.Match(shape, option)) {
979 if (option.capture && matched_shape_) {
980 *matched_shape_ = shape;
981 }
982 return true;
983 }
984 EXPLAIN << "\nin "
985 << (shape->has_layout() ? ShapeUtil::HumanStringWithLayout(*shape)
986 : ShapeUtil::HumanString(*shape));
987 return false;
988 }
989
990 void DescribeTo(std::ostream* os, int64 indent = 0) const {
991 return impl_.DescribeTo(os, indent);
992 }
993
994 // Modifies the pattern to match only if the shape equals the given proto.
995 // The layout must outlive the returned pattern.
996 constexpr auto EqualTo(const ::xla::Shape* shape) const
997 -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) {
998 return AppendImpl(ShapePatternEqualImpl(shape));
999 }
1000
1001 // Modifies the pattern to match only if the shape is compatible to the given
1002 // proto. The layout must outlive the returned pattern.
1003 constexpr auto CompatibleTo(const ::xla::Shape* shape) const
1004 -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) {
1005 return AppendImpl(ShapePatternCompatibleImpl(shape));
1006 }
1007
1008 // Modifies the pattern to match only if the shape has the given element type.
1009 constexpr auto WithElementType(PrimitiveType element_type) const
1010 -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) {
1011 return AppendImpl(ShapePatternElementTypeImpl(element_type));
1012 }
1013
1014 // Modifies the pattern to match only if the shape is scalar.
1015 constexpr auto IsScalar() const
1016 -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) {
1017 return AppendImpl(ShapePatternIsScalarImpl());
1018 }
1019
1020 // Modifies the pattern to match only if the shape is an array.
1021 constexpr auto IsArray() const
1022 -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) {
1023 return AppendImpl(ShapePatternIsArrayImpl());
1024 }
1025
1026 // Modifies the pattern to match only if the shape is a tuple.
1027 constexpr auto IsTuple() const
1028 -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) {
1029 return AppendImpl(ShapePatternIsTupleImpl());
1030 }
1031
1032 constexpr auto IsEffectiveScalar() const
1033 -> decltype(this->AppendImpl(ShapePatternEffectiveScalarImpl())) {
1034 return AppendImpl(ShapePatternEffectiveScalarImpl());
1035 }
1036
1037 // Modifies the pattern to match only if the shape has the given rank.
1038 constexpr auto WithRank(int64 rank) const
1039 -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) {
1040 return AppendImpl(ShapePatternRankImpl(rank));
1041 }
1042
1043 // Modifies the pattern to match only if the shape has a layout that matches
1044 // the given pattern.
1045 template <typename LayoutType, typename LayoutImpl>
1046 auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const
1047 -> decltype(this->AppendImpl(
1048 ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout))) {
1049 return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
1050 }
1051
1052 constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const
1053 -> decltype(this->WithLayout(Layout().EqualTo(layout))) {
1054 return WithLayout(Layout().EqualTo(layout));
1055 }
1056
1057 constexpr auto IsDenseArray() const
1058 -> decltype(this->WithLayout(Layout().WithDenseFormat())) {
1059 return WithLayout(Layout().WithDenseFormat());
1060 }
1061
1062 constexpr auto IsSparseArray() const
1063 -> decltype(this->WithLayout(Layout().WithSparseFormat())) {
1064 return WithLayout(Layout().WithSparseFormat());
1065 }
1066
1067 // Modifies the pattern to match only if the shape has a subshape that matches
1068 // the given pattern.
1069 template <typename SubshapeType, typename SubshapeImpl>
1070 auto WithSubshape(ShapeIndexView index,
1071 const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
1072 const -> decltype(this->AppendImpl(
1073 ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index,
1074 subshape))) {
1075 return AppendImpl(
1076 ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
1077 }
1078
1079 ShapePattern<ShapeType,
1080 AllOfPattern<Shape, Impl,
1081 ShapePatternSubshapeImpl<
1082 const ::xla::Shape,
1083 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1084 ShapePatternEqualImpl>>>>
1085 WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
1086 return WithSubshape(index,
1087 ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1088 ShapePatternBaseImpl(), nullptr)
1089 .EqualTo(shape));
1090 }
1091
1092 ShapePattern<ShapeType,
1093 AllOfPattern<Shape, Impl,
1094 ShapePatternSubshapeImpl<
1095 const ::xla::Shape,
1096 AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
1097 ShapePatternCompatibleImpl>>>>
1098 WithSubshapeCompatibleTo(ShapeIndexView index,
1099 const ::xla::Shape* shape) const {
1100 return WithSubshape(index,
1101 ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
1102 ShapePatternBaseImpl(), nullptr)
1103 .CompatibleTo(shape));
1104 }
1105
1106 private:
1107 Impl impl_;
1108 ShapeType** matched_shape_;
1109 };
1110
1111 } // namespace detail
1112
1113 // Creates a shape pattern that will capture the matched layout in the argument.
1114 inline constexpr detail::ShapePattern<const ::xla::Shape,
1115 detail::ShapePatternBaseImpl>
1116 Shape(const ::xla::Shape** matched_shape = nullptr) {
1117 return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>(
1118 detail::ShapePatternBaseImpl(), matched_shape);
1119 }
1120
1121 // Creates a shape pattern that will capture the matched layout in the argument.
1122 inline constexpr detail::ShapePattern<::xla::Shape,
1123 detail::ShapePatternBaseImpl>
1124 Shape(::xla::Shape** matched_shape) {
1125 return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>(
1126 detail::ShapePatternBaseImpl(), matched_shape);
1127 }
1128
1129 namespace detail {
1130
1131 // Overloads to get a const or non-const operand out of an instruction.
1132 inline HloInstruction* HloOperand(HloInstruction* instr, int64 idx) {
1133 return instr->mutable_operand(idx);
1134 }
1135 inline const HloInstruction* HloOperand(const HloInstruction* instr,
1136 int64 idx) {
1137 return instr->operand(idx);
1138 }
1139
1140 // Pretty-printer for HloInstruction. Sort of like ToShortString, but with
1141 // fewer %s and more shapes.
1142 inline string InstToString(const HloInstruction* inst) {
1143 return inst->ToString(
1144 HloPrintOptions().set_print_metadata(false).set_print_percent(false));
1145 }
1146
1147 template <typename HloInstructionType, typename Impl>
1148 class HloInstructionPattern;
1149
1150 // The base HloInstructionPattern implementation. Matches only if the
1151 // instruction is not nullptr.
1152 class HloInstructionPatternBaseImpl {
1153 public:
1154 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1155 if (inst == nullptr) {
1156 EXPLAIN << "HloInstruction* is null";
1157 return false;
1158 }
1159 return true;
1160 }
1161
1162 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1163 *os << "an HloInstruction";
1164 }
1165
1166 static constexpr bool kIsTrivialMatcher = true;
1167 };
1168
1169 // An HloInstructionPattern implementation that matches only if the instruction
1170 // has a given name.
1171 class HloInstructionPatternNameImpl {
1172 public:
1173 explicit HloInstructionPatternNameImpl(absl::string_view name)
1174 : name_(name) {}
1175
1176 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1177 if (inst->name() != name_) {
1178 EXPLAIN << "HloInstruction not named \"" << name_ << "\"";
1179 return false;
1180 }
1181 return true;
1182 }
1183
1184 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1185 *os << "named \"" << name_ << "\"";
1186 }
1187
1188 private:
1189 absl::string_view name_;
1190 };
1191
1192 // An HloInstructionPattern implementation that matches only if the instruction
1193 // equals a particular pointer.
1194 class HloInstructionIsImpl {
1195 public:
1196 explicit HloInstructionIsImpl(const HloInstruction* inst) : inst_(inst) {}
1197
1198 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1199 if (inst != inst_) {
1200 EXPLAIN << "HloInstruction " << inst << " is not " << inst_ << " ("
1201 << InstToString(inst_) << ")";
1202 return false;
1203 }
1204 return true;
1205 }
1206
1207 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1208 *os << "which is " << inst_ << " (" << InstToString(inst_) << ")";
1209 }
1210
1211 private:
1212 const HloInstruction* inst_;
1213 };
1214
1215 // An HloInstructionPattern implementation that matches only if the instruction
1216 // has a given opcode.
1217 class HloInstructionPatternOpcodeImpl {
1218 public:
1219 explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode,
1220 bool invert)
1221 : opcode_(opcode), invert_(invert) {}
1222
1223 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1224 if (invert_ && inst->opcode() == opcode_) {
1225 EXPLAIN << "HloInstruction has opcode " << HloOpcodeString(opcode_)
1226 << ", expected anything else";
1227 return false;
1228 }
1229 if (!invert_ && inst->opcode() != opcode_) {
1230 EXPLAIN << "HloInstruction doesn't have opcode "
1231 << HloOpcodeString(opcode_);
1232 return false;
1233 }
1234 return true;
1235 }
1236
1237 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1238 if (!invert_) {
1239 *os << "with opcode " << HloOpcodeString(opcode_);
1240 } else {
1241 *os << "with any opcode other than " << HloOpcodeString(opcode_);
1242 }
1243 }
1244
1245 private:
1246 HloOpcode opcode_;
1247 bool invert_;
1248 };
1249
1250 // An HloInstructionPattern implementation that matches only if the instruction
1251 // has the given number of operands.
1252 class HloInstructionPatternNumOperandsImpl {
1253 public:
1254 explicit constexpr HloInstructionPatternNumOperandsImpl(int64 num_operands)
1255 : num_operands_(num_operands) {}
1256
1257 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1258 if (inst->operand_count() != num_operands_) {
1259 EXPLAIN << "HloInstruction doesn't have " << num_operands_ << " operands";
1260 return false;
1261 }
1262 return true;
1263 }
1264
1265 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1266 *os << "with " << num_operands_ << " operand"
1267 << (num_operands_ != 1 ? "s" : "");
1268 }
1269
1270 private:
1271 int64 num_operands_;
1272 };
1273
1274 // An HloInstructionPattern implementation that matches only if the instruction
1275 // has a shape that matches a given pattern.
1276 template <typename ShapeType, typename ShapeImpl>
1277 class HloInstructionPatternShapeImpl {
1278 public:
1279 explicit constexpr HloInstructionPatternShapeImpl(
1280 const ShapePattern<ShapeType, ShapeImpl>& shape)
1281 : shape_(shape) {}
1282
1283 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1284 if (!shape_.Match(&inst->shape(), option)) {
1285 EXPLAIN << "\nin output shape";
1286 return false;
1287 }
1288 return true;
1289 }
1290
1291 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1292 if (!shape_.Match(inst->mutable_shape(), option)) {
1293 EXPLAIN << "\nin output shape";
1294 return false;
1295 }
1296 return true;
1297 }
1298
1299 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1300 *os << "outputting";
1301 Indent(os, indent + kIndentInc);
1302 shape_.DescribeTo(os, indent + kIndentInc);
1303 }
1304
1305 private:
1306 ShapePattern<ShapeType, ShapeImpl> shape_;
1307 };
1308
1309 // An HloInstructionPattern implementation that matches only if the instruction
1310 // has an operand that matches a given pattern.
1311 template <typename OperandType, typename OperandImpl>
1312 class HloInstructionPatternOperandImpl {
1313 public:
1314 explicit constexpr HloInstructionPatternOperandImpl(
1315 int64 operand_index,
1316 const HloInstructionPattern<OperandType, OperandImpl>& operand)
1317 : operand_index_(operand_index), operand_(operand) {}
1318
1319 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1320 return MatchImpl(inst, option);
1321 }
1322
1323 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1324 return MatchImpl(inst, option);
1325 }
1326
1327 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1328 *os << "with operand " << operand_index_ << " which is:";
1329 Indent(os, indent + kIndentInc);
1330 operand_.DescribeTo(os, indent + kIndentInc);
1331 }
1332
1333 private:
1334 template <typename HloInstructionType>
1335 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1336 if (operand_index_ >= inst->operand_count()) {
1337 EXPLAIN << "desired operand index " << operand_index_
1338 << " is out of bounds";
1339 return false;
1340 }
1341 if (!operand_.Match(HloOperand(inst, operand_index_), option)) {
1342 EXPLAIN << "\nin operand " << operand_index_;
1343 return false;
1344 }
1345 return true;
1346 }
1347
1348 int64 operand_index_;
1349 HloInstructionPattern<OperandType, OperandImpl> operand_;
1350 };
1351
// Matches a binary instruction whose operands come in any order.
//
// The instruction matches iff it has exactly two operands and either
// (op1 matches operand 0 and op2 matches operand 1) or
// (op1 matches operand 1 and op2 matches operand 0). Captures are performed
// only for the assignment that succeeded.
template <typename OperandType1, typename OperandImpl1, typename OperandType2,
          typename OperandImpl2>
class HloInstructionPatternBinaryOperandsAnyOrderImpl {
 public:
  explicit constexpr HloInstructionPatternBinaryOperandsAnyOrderImpl(
      const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
      const HloInstructionPattern<OperandType2, OperandImpl2>& op2)
      : op1_(op1), op2_(op2) {}

  bool Match(HloInstruction* inst, MatchOption option) const {
    return MatchImpl(inst, option);
  }

  bool Match(const HloInstruction* inst, MatchOption option) const {
    return MatchImpl(inst, option);
  }

  // Describes both operand matchers as a bulleted list. Each nested
  // description is indented 3 extra columns so that it lines up after the
  // " - " bullet.
  void DescribeTo(std::ostream* os, int64 indent = 0) const {
    *os << "with two operands in either order:";
    Indent(os, indent);
    *os << " - ";
    op1_.DescribeTo(os, indent + 3);
    Indent(os, indent);
    *os << " - ";
    op2_.DescribeTo(os, indent + 3);
  }

 private:
  // Fetches operand `idx`, preserving the constness of `inst`.
  HloInstruction* operand(HloInstruction* inst, int64 idx) const {
    return inst->mutable_operand(idx);
  }
  const HloInstruction* operand(const HloInstruction* inst, int64 idx) const {
    return inst->operand(idx);
  }

  template <typename HloInstructionType>
  bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
    // We could implement this using AnyOf and AllOf matchers, but the templates
    // get pretty difficult to debug, since any compile error herein becomes
    // not-an-error via SFINAE. Also this way lets us give better messages on
    // failure.
    if (inst->operand_count() != 2) {
      EXPLAIN << "HloInstruction did not have two operands";
      return false;
    }

    // If we're not generating explanations, this is pretty simple.
    if (!option.explain_os) {
      // Tries op1_ on operand idx1 and op2_ on operand idx2. Both are first
      // checked with capture disabled, so a failed attempt (e.g. op1_ matches
      // but op2_ does not) writes no captures. On success the match is rerun
      // with the caller's option to perform the captures.
      auto try_match = [&](int64 idx1, int64 idx2) {
        MatchOption new_option = option;
        new_option.capture = false;
        if (op1_.Match(operand(inst, idx1), new_option) &&
            op2_.Match(operand(inst, idx2), new_option)) {
          if (option.capture) {
            bool matched = op1_.Match(operand(inst, idx1), option) &&
                           op2_.Match(operand(inst, idx2), option);
            DCHECK(matched);
          }
          return true;
        }
        return false;
      };
      // Straight order first, then swapped.
      return try_match(0, 1) || try_match(1, 0);
    }

    // If we are generating explanations, we have some work to do in order to
    // generate a helpful error.
    //
    // First, try all four operand/matcher combinations, recording the
    // failure explanations separately from option.explain_os. matches[i][j]
    // tells us if matcher_i matches operand j.
    bool matches[/*matcher*/ 2][/*operand*/ 2];
    std::stringstream explanations[/*matcher*/ 2][/*operand*/ 2];
    for (int i = 0; i < 2; ++i) {
      for (int j = 0; j < 2; ++j) {
        MatchOption new_option = option;
        new_option.capture = false;
        new_option.explain_os = &explanations[i][j];
        matches[i][j] = i == 0 ? op1_.Match(operand(inst, j), new_option)
                               : op2_.Match(operand(inst, j), new_option);
      }
    }

    // Check if the match succeeded: op1 on operand i together with op2 on the
    // other operand.
    for (int i = 0; i < 2; ++i) {
      if (matches[0][i] && matches[1][(i + 1) % 2]) {
        // Rerun the matches with capture enabled if necessary.
        if (option.capture) {
          auto* operand1 = operand(inst, i);
          auto* operand2 = operand(inst, (i + 1) % 2);
          bool matched =
              op1_.Match(operand1, option) && op2_.Match(operand2, option);
          DCHECK(matched);
        }
        return true;
      }
    }

    // Appends matcher `matcher_idx`'s description, followed by the recorded
    // explanation for each operand it failed to match (indented one space
    // per line).
    auto describe_matcher = [&](int matcher_idx) {
      EXPLAIN << "\n - ";
      if (matcher_idx == 0) {
        op1_.DescribeTo(option.explain_os, /*indent=*/3);
      } else {
        CHECK_EQ(matcher_idx, 1);
        op2_.DescribeTo(option.explain_os, /*indent=*/3);
      }
      for (int i = 0; i < 2; ++i) {
        if (matches[matcher_idx][/*operand*/ i]) {
          continue;
        }
        EXPLAIN << "\ndoes not match " << (i == 0 ? "LHS" : "RHS") << ":\n";
        EXPLAIN << " - ";
        EXPLAIN << absl::StrReplaceAll(
            explanations[matcher_idx][/*operand*/ i].str(), {{"\n", "\n   "}});
      }
    };

    // If we failed to match, one of the following is true:
    //  1. op1 (op2) matches neither LHS nor RHS, or
    //  2. op1 and op2 both match LHS (RHS), but neither matches RHS (LHS).
    // We print different explanations depending on which case we're in.
    // (If neither case held, each matcher would match a distinct operand and
    // the success loop above would have returned, so one case always applies.)

    // Case 1.
    bool wrote_explanation = false;
    for (int i = 0; !wrote_explanation && i < 2; ++i) {
      if (!matches[i][0] && !matches[i][1]) {
        EXPLAIN << "HloInstruction's operands (ignoring order) did not match "
                << (i == 0 ? "first" : "second") << " matcher.  Specifically,";
        describe_matcher(i);
        wrote_explanation = true;
      }
    }

    // Case 2. Both matchers accept operand i, so the other operand
    // ((i + 1) % 2) is the one that neither matcher accepts.
    for (int i = 0; !wrote_explanation && i < 2; ++i) {
      if (matches[/*matcher*/ 0][/*operand*/ i] &&
          matches[/*matcher*/ 1][/*operand*/ i]) {
        CHECK(!matches[0][(i + 1) % 2]);
        CHECK(!matches[1][(i + 1) % 2]);
        CHECK(!wrote_explanation);
        EXPLAIN << "HloInstruction's " << (i == 1 ? "LHS" : "RHS")
                << " operand did not match either of the two matchers.  "
                   "Specifically,";
        describe_matcher(0);
        EXPLAIN << "\nand";
        describe_matcher(1);
        wrote_explanation = true;
      }
    }

    CHECK(wrote_explanation);
    return false;
  }

  HloInstructionPattern<OperandType1, OperandImpl1> op1_;
  HloInstructionPattern<OperandType2, OperandImpl2> op2_;
};
1510
1511 // An HloInstructionPattern implementation that matches only if the instruction
1512 // is a fusion node with a particular kind.
1513 class HloInstructionPatternFusionKindImpl {
1514 public:
1515 explicit constexpr HloInstructionPatternFusionKindImpl(
1516 ::xla::HloInstruction::FusionKind kind)
1517 : kind_(kind) {}
1518
1519 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1520 return MatchImpl(inst, option);
1521 }
1522
1523 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1524 return MatchImpl(inst, option);
1525 }
1526
1527 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1528 *os << "with fusion kind " << ToString(kind_);
1529 }
1530
1531 private:
1532 template <typename HloInstructionType>
1533 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1534 if (inst->opcode() != HloOpcode::kFusion) {
1535 EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_)
1536 << "; it's not a fusion";
1537 return false;
1538 }
1539 if (inst->fusion_kind() != kind_) {
1540 EXPLAIN << "HloInstruction does not have fusion kind " << ToString(kind_);
1541 return false;
1542 }
1543 return true;
1544 }
1545
1546 ::xla::HloInstruction::FusionKind kind_;
1547 };
1548
1549 // An HloInstructionPattern implementation that matches only if the instruction
1550 // is a kGetTupleElement with a particular tuple index.
1551 class HloInstructionPatternTupleIndexImpl {
1552 public:
1553 explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index)
1554 : tuple_index_(tuple_index) {}
1555
1556 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1557 return MatchImpl(inst, option);
1558 }
1559
1560 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1561 return MatchImpl(inst, option);
1562 }
1563
1564 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1565 *os << "which is a GTE with index " << tuple_index_;
1566 }
1567
1568 private:
1569 template <typename HloInstructionType>
1570 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1571 if (inst->opcode() != HloOpcode::kGetTupleElement) {
1572 EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_
1573 << "; it's not a GTE at all";
1574 return false;
1575 }
1576 if (inst->tuple_index() != tuple_index_) {
1577 EXPLAIN << "HloInstruction is not a GTE with index " << tuple_index_;
1578 return false;
1579 }
1580 return true;
1581 }
1582
1583 int64 tuple_index_;
1584 };
1585
1586 class HloInstructionPatternParameterNumImpl {
1587 public:
1588 explicit constexpr HloInstructionPatternParameterNumImpl(int64 parameter_num)
1589 : parameter_num_(parameter_num) {}
1590
1591 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1592 return MatchImpl(inst, option);
1593 }
1594
1595 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1596 return MatchImpl(inst, option);
1597 }
1598
1599 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1600 *os << "which is parameter " << parameter_num_;
1601 }
1602
1603 private:
1604 template <typename HloInstructionType>
1605 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1606 if (inst->opcode() != HloOpcode::kParameter ||
1607 inst->parameter_number() != parameter_num_) {
1608 EXPLAIN << "HloInstruction is not parameter " << parameter_num_;
1609 return false;
1610 }
1611 return true;
1612 }
1613
1614 int64 parameter_num_;
1615 };
1616
1617 // Superclass that contains common code used by Op::WithOneUse() and
1618 // Op::WithOneUser().
1619 class HloInstructionPatternOneUseOrUserImpl {
1620 protected:
1621 bool MatchOneUser(const HloInstruction* inst, MatchOption option) const {
1622 if (inst->user_count() != 1) {
1623 EXPLAIN << "HloInstruction has " << inst->user_count()
1624 << " users, but expected exactly one.";
1625 if (inst->user_count() > 1) {
1626 EXPLAIN << "\nAll users:";
1627 for (const HloInstruction* user : inst->users()) {
1628 EXPLAIN << "\n - " << InstToString(user);
1629 }
1630 }
1631 return false;
1632 }
1633 return true;
1634 }
1635 };
1636
1637 class HloInstructionPatternOneUseImpl
1638 : public HloInstructionPatternOneUseOrUserImpl {
1639 public:
1640 bool Match(const HloInstruction* inst, MatchOption option) const {
1641 if (!MatchOneUser(inst, option)) {
1642 return false;
1643 }
1644
1645 int64 use_count = absl::c_count_if(
1646 inst->users()[0]->operands(),
1647 [&](const HloInstruction* operand) { return operand == inst; });
1648 if (use_count != 1) {
1649 EXPLAIN << "HloInstruction is used " << use_count
1650 << " times by its user, but is expected to be used just once: "
1651 << InstToString(inst->users()[0]);
1652 return false;
1653 }
1654 return true;
1655 }
1656
1657 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1658 *os << "which has exactly one use";
1659 }
1660 };
1661
1662 class HloInstructionPatternOneUserImpl
1663 : public HloInstructionPatternOneUseOrUserImpl {
1664 public:
1665 bool Match(const HloInstruction* inst, MatchOption option) const {
1666 return MatchOneUser(inst, option);
1667 }
1668
1669 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1670 *os << "which has exactly one user (but possibly is used multiple times by "
1671 "that instruction)";
1672 }
1673 };
1674
1675 class HloInstructionPatternComparisonDirectionImpl {
1676 public:
1677 explicit constexpr HloInstructionPatternComparisonDirectionImpl(
1678 ComparisonDirection direction)
1679 : direction_(direction) {}
1680
1681 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1682 return MatchImpl(inst, option);
1683 }
1684
1685 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1686 return MatchImpl(inst, option);
1687 }
1688
1689 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1690 *os << "which has comparison direction "
1691 << ComparisonDirectionToString(direction_);
1692 }
1693
1694 private:
1695 template <typename HloInstructionType>
1696 bool MatchImpl(HloInstructionType* inst, MatchOption option) const {
1697 if (inst->opcode() != HloOpcode::kCompare ||
1698 inst->comparison_direction() != direction_) {
1699 EXPLAIN << "HloInstruction is not comparison "
1700 << ComparisonDirectionToString(direction_);
1701 return false;
1702 }
1703 return true;
1704 }
1705
1706 ComparisonDirection direction_;
1707 };
1708
1709 // Matches a constant scalar or effective scalar, optionally with a given value.
1710 template <typename ScalarTy>
1711 class HloConstantScalarImpl {
1712 public:
1713 explicit constexpr HloConstantScalarImpl(bool match_effective_scalar)
1714 : val_(absl::nullopt), match_effective_scalar_(match_effective_scalar) {}
1715
1716 constexpr HloConstantScalarImpl(ScalarTy val, bool match_effective_scalar)
1717 : val_(val), match_effective_scalar_(match_effective_scalar) {}
1718
1719 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1720 return MatchImpl(inst, option);
1721 }
1722
1723 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1724 return MatchImpl(inst, option);
1725 }
1726
1727 void DescribeTo(std::ostream* os, int64 indent = 0) const {
1728 *os << "which is a constant "
1729 << (match_effective_scalar_ ? "effective " : "") << "scalar";
1730 if (val_.has_value()) {
1731 *os << " with value " << *val_;
1732 }
1733 }
1734
1735 private:
1736 template <typename InstTy>
1737 bool MatchImpl(InstTy* inst, MatchOption option) const {
1738 const auto* const_inst = DynCast<HloConstantInstruction>(inst);
1739 if (!const_inst) {
1740 EXPLAIN << "HloInstruction is not a constant";
1741 return false;
1742 }
1743 if (match_effective_scalar_ &&
1744 !ShapeUtil::IsEffectiveScalar(inst->shape())) {
1745 EXPLAIN << "HloInstruction is not an effective scalar";
1746 return false;
1747 }
1748 if (!match_effective_scalar_ && !ShapeUtil::IsScalar(inst->shape())) {
1749 EXPLAIN << "HloInstruction is not a scalar";
1750 return false;
1751 }
1752 if (!val_.has_value()) {
1753 return true;
1754 }
1755
1756 // Check that literal == static_cast<LitearlTy>(val) and
1757 // val == static_cast<ValTy>(literal). This is sufficient to ensure that
1758 // the two constant scalars are actually "equal".
1759 auto val_literal = LiteralUtil::CreateR0(*val_);
1760 auto literal_r0_or = const_inst->literal().Reshape({});
1761 auto val_as_literal_ty_or =
1762 val_literal.Convert(const_inst->shape().element_type());
1763 if (!literal_r0_or.ok() || !val_as_literal_ty_or.ok()) {
1764 EXPLAIN << "could not construct relevant Literals (how did this happen?)";
1765 return false;
1766 }
1767 auto literal_r0 = std::move(literal_r0_or).ValueOrDie();
1768 auto val_as_literal_ty = std::move(val_as_literal_ty_or).ValueOrDie();
1769 auto literal_r0_as_val_ty_or =
1770 literal_r0.Convert(val_literal.shape().element_type());
1771 bool rv = literal_r0_as_val_ty_or.ok() && //
1772 literal_r0_as_val_ty_or.ValueOrDie() == val_literal &&
1773 literal_r0 == val_as_literal_ty;
1774 if (!rv) {
1775 EXPLAIN << "HloInstruction's constant value "
1776 << literal_r0.ToStringWithoutShape()
1777 << " did not match expected value " << *val_;
1778 }
1779 return rv;
1780 }
1781
1782 absl::optional<ScalarTy> val_;
1783 bool match_effective_scalar_;
1784 };
1785
1786 // A pattern that matches HloInstructions.
1787 template <typename HloInstructionType, typename Impl>
1788 class HloInstructionPattern {
1789 private:
// Returns a copy of this pattern whose matcher is AllOf(impl_, new_impl),
// keeping the same capture pointer. All With*/Is* modifiers are built on this.
template <typename NewImpl>
auto AppendImpl(NewImpl new_impl) const -> HloInstructionPattern<
    HloInstructionType, decltype(AllOf<HloInstruction>(
                            std::declval<Impl>(), std::move(new_impl)))> {
  auto new_allof = AllOf<HloInstruction>(impl_, std::move(new_impl));
  return HloInstructionPattern<HloInstructionType, decltype(new_allof)>(
      std::move(new_allof), matched_inst_);
}

 public:
// `impl` decides whether an instruction matches. On a successful match with
// capturing enabled, the instruction is stored through `matched_inst` (if
// non-null).
explicit constexpr HloInstructionPattern(const Impl& impl,
                                         HloInstructionType** matched_inst)
    : impl_(impl), matched_inst_(matched_inst) {}
1803
1804 // Returns true and captures the instruction iff it matches the pattern.
1805 bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
1806 if (impl_.Match(inst, option)) {
1807 if (option.capture && matched_inst_) {
1808 *matched_inst_ = inst;
1809 }
1810 return true;
1811 }
1812 if (inst != nullptr) {
1813 EXPLAIN << "\nin " << InstToString(inst);
1814 }
1815 return false;
1816 }
1817
1818 // Returns true and captures the instruction iff it matches the pattern.
1819 bool Match(::xla::HloInstruction* inst, MatchOption option) const {
1820 if (impl_.Match(inst, option)) {
1821 if (option.capture && matched_inst_) {
1822 *matched_inst_ = inst;
1823 }
1824 return true;
1825 }
1826 EXPLAIN << "\nin " << InstToString(inst);
1827 return false;
1828 }
1829
// Modifies the pattern to match only if the instruction has the given name.
// The name is held as an absl::string_view, so the characters it refers to
// must outlive the returned pattern.
auto WithName(absl::string_view name) const
    -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) {
  return AppendImpl(HloInstructionPatternNameImpl(name));
}

// Modifies the pattern to match only if the instruction has the given opcode.
auto WithOpcode(HloOpcode opcode) const
    -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
                                                                 false))) {
  return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
}

// Modifies the pattern to match only if the instruction has exactly
// `num_operands` operands.
auto WithNumOperands(int64 num_operands) const -> decltype(
    this->AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands))) {
  return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands));
}

// Modifies the pattern to match only if the instruction does not have the
// given opcode.
auto WithoutOpcode(HloOpcode opcode) const
    -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
                                                                 true))) {
  return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true));
}

// Modifies the pattern to match only the instruction `instr` itself, compared
// by pointer identity.
constexpr auto Is(const HloInstruction* instr) const
    -> decltype(this->AppendImpl(HloInstructionIsImpl(instr))) {
  return AppendImpl(HloInstructionIsImpl(instr));
}
1860
1861 // Modifies the pattern to match only if the instruction is a constant.
1862 constexpr auto IsConstant() const
1863 -> decltype(this->WithOpcode(HloOpcode::kConstant)) {
1864 return WithOpcode(HloOpcode::kConstant);
1865 }
1866
1867 constexpr auto IsConstantScalar() const -> decltype(this->AppendImpl(
1868 HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false))) {
1869 return AppendImpl(
1870 HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/false));
1871 }
1872
1873 // This does not check that T has the same type as the instruction, so e.g.
1874 // IsConstantScalar(1.0) may match a constant of shape int32[].
1875 template <typename ScalarTy>
1876 constexpr auto IsConstantScalar(const ScalarTy& val) const
1877 -> decltype(this->AppendImpl(HloConstantScalarImpl<ScalarTy>(
1878 val, /*match_effective_scalar=*/false))) {
1879 return AppendImpl(
1880 HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/false));
1881 }
1882
1883 constexpr auto IsConstantEffectiveScalar() const -> decltype(this->AppendImpl(
1884 HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true))) {
1885 return AppendImpl(
1886 HloConstantScalarImpl</*Dummy*/ int>(/*match_effective_scalar=*/true));
1887 }
1888
1889 template <typename ScalarTy>
1890 constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const
1891 -> decltype(this->AppendImpl(HloConstantScalarImpl<ScalarTy>(
1892 val, /*match_effective_scalar=*/true))) {
1893 return AppendImpl(
1894 HloConstantScalarImpl<ScalarTy>(val, /*match_effective_scalar=*/true));
1895 }
1896
1897 // Modifies the pattern to match only if the instruction is not a constant.
1898 constexpr auto IsNonConstant() const
1899 -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) {
1900 return WithoutOpcode(HloOpcode::kConstant);
1901 }
1902
1903 // Modifies the pattern to match only if the instruction has a shape that
1904 // matches the given pattern.
1905 template <typename ShapeType, typename ShapeImpl>
1906 constexpr auto WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape)
1907 const -> decltype(this->AppendImpl(
1908 HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape))) {
1909 return AppendImpl(
1910 HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
1911 }
1912
1913 // Make this a templated function to work around gcc 4.9.4 template infinite
1914 // recursion bug.
1915 template <typename Dummy = void>
1916 constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const
1917 -> decltype(this->WithShape(Shape().EqualTo(shape))) {
1918 return WithShape(Shape().EqualTo(shape));
1919 }
1920
1921 // Make this a templated function to work around gcc 4.9.4 template infinite
1922 // recursion bug.
1923 template <typename Dummy = void>
1924 constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const
1925 -> decltype(this->WithShape(Shape().CompatibleTo(shape))) {
1926 return WithShape(Shape().CompatibleTo(shape));
1927 }
1928
1929 // Modifies the pattern to match only if the instruction has an operand that
1930 // matches the given pattern.
1931 template <typename OperandType, typename OperandImpl>
1932 constexpr auto WithOperand(
1933 int64 operand_index,
1934 const HloInstructionPattern<OperandType, OperandImpl>& operand) const
1935 -> decltype(this->AppendImpl(
1936 HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
1937 operand_index, operand))) {
1938 return AppendImpl(
1939 HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
1940 operand_index, operand));
1941 }
1942
1943 template <typename OperandType1, typename OperandImpl1, typename OperandType2,
1944 typename OperandImpl2>
1945 constexpr auto WithBinaryOperandsAnyOrder(
1946 const HloInstructionPattern<OperandType1, OperandImpl1>& op1,
1947 const HloInstructionPattern<OperandType2, OperandImpl2>& op2) const
1948 -> decltype(this->AppendImpl(
1949 HloInstructionPatternBinaryOperandsAnyOrderImpl<
1950 OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1,
1951 op2))) {
1952 return AppendImpl(
1953 HloInstructionPatternBinaryOperandsAnyOrderImpl<
1954 OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2));
1955 }
1956
1957 // Modifies the pattern to match only if the instruction is a fusion node with
1958 // the given kind.
1959 constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const
1960 -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) {
1961 return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
1962 }
1963
1964 // Modifies the pattern to match only if the instruction is a
1965 // get-tuple-element with the given tuple index.
1966 constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype(
1967 this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) {
1968 return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
1969 }
1970
1971 // Modifies the pattern to match only if the instruction is a parameter
1972 // with the given parameter number.
1973 constexpr auto WithParameterNum(int64 parameter_num) const -> decltype(
1974 this->AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num))) {
1975 return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num));
1976 }
1977
1978 // Modifies the pattern to match if the instruction is used exactly once.
1979 // Does not match if the instruction is used twice by the same user (e.g.
1980 // multiply(x,x)).
1981 constexpr auto WithOneUse() const
1982 -> decltype(this->AppendImpl(HloInstructionPatternOneUseImpl())) {
1983 return AppendImpl(HloInstructionPatternOneUseImpl());
1984 }
1985
1986 // Modifies the pattern to match if the instruction is used by exactly one
1987 // other instruction. Will match if the instruction is used twice, so long as
1988 // it's by the same user (e.g. multiply(x,x)).
1989 constexpr auto WithOneUser() const
1990 -> decltype(this->AppendImpl(HloInstructionPatternOneUserImpl())) {
1991 return AppendImpl(HloInstructionPatternOneUserImpl());
1992 }
1993
1994 // Modifies the pattern to match only if the instruction has the given
1995 // comparison direction.
1996 auto WithComparisonDirection(ComparisonDirection direction) const
1997 -> decltype(this->AppendImpl(
1998 HloInstructionPatternComparisonDirectionImpl(direction))) {
1999 return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction));
2000 }
2001
2002 void DescribeTo(std::ostream* os, int64 indent = 0) const {
2003 impl_.DescribeTo(os, indent);
2004 }
2005
2006 private:
2007 Impl impl_;
2008 HloInstructionType** matched_inst_;
2009 };
2010
2011 } // namespace detail
2012
2013 // Creates an instruction pattern that will capture the matched instruction in
2014 // the argument.
2015 inline constexpr detail::HloInstructionPattern<
2016 const ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl>
2017 Op(const ::xla::HloInstruction** matched_inst = nullptr) {
2018 return detail::HloInstructionPattern<const ::xla::HloInstruction,
2019 detail::HloInstructionPatternBaseImpl>(
2020 detail::HloInstructionPatternBaseImpl(), matched_inst);
2021 }
2022
2023 // Creates an instruction pattern that will capture the matched instruction in
2024 // the argument.
2025 inline constexpr detail::HloInstructionPattern<
2026 ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl>
2027 Op(::xla::HloInstruction** matched_inst) {
2028 return detail::HloInstructionPattern<::xla::HloInstruction,
2029 detail::HloInstructionPatternBaseImpl>(
2030 detail::HloInstructionPatternBaseImpl(), matched_inst);
2031 }
2032
2033 // Helpers for nullary instructions.
2034 #define XLA_NULLOP_PATTERN(NAME) \
2035 inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
2036 return Op().WithOpcode(HloOpcode::k##NAME); \
2037 } \
2038 \
2039 template <typename HloInstructionType> \
2040 inline auto NAME(HloInstructionType** matched_inst) \
2041 ->decltype(Op(matched_inst).WithOpcode(HloOpcode::k##NAME)) { \
2042 return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \
2043 }
2044 XLA_NULLOP_PATTERN(Constant)
2045 XLA_NULLOP_PATTERN(Parameter)
2046 XLA_NULLOP_PATTERN(Iota)
2047 XLA_NULLOP_PATTERN(Rng)
2048 #undef XLA_NULLOP_PATTERN
2049
// Helpers for unary instructions.
//
// XLA_UNOP_PATTERN(NAME) defines three overloads of NAME, each matching an
// instruction whose opcode is HloOpcode::k##NAME:
//   NAME()                   -- any operands;
//   NAME(arg)                -- operand 0 must match the pattern `arg`;
//   NAME(matched_inst, arg)  -- as NAME(arg), and additionally captures the
//                               matched instruction through `matched_inst`.
// Only operand 0 is constrained; the operand count itself is not checked.
#define XLA_UNOP_PATTERN(NAME)                                       \
  inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
    return Op().WithOpcode(HloOpcode::k##NAME);                      \
  }                                                                  \
                                                                     \
  template <typename Arg>                                            \
  inline auto NAME(Arg&& arg)->decltype(                             \
      Op().WithOpcode(HloOpcode::k##NAME)                            \
          .WithOperand(0, std::forward<Arg>(arg))) {                 \
    return Op()                                                      \
        .WithOpcode(HloOpcode::k##NAME)                              \
        .WithOperand(0, std::forward<Arg>(arg));                     \
  }                                                                  \
                                                                     \
  template <typename HloInstructionType, typename Arg>               \
  inline auto NAME(HloInstructionType** matched_inst, Arg&& arg)     \
      ->decltype(Op(matched_inst)                                    \
                     .WithOpcode(HloOpcode::k##NAME)                 \
                     .WithOperand(0, std::forward<Arg>(arg))) {      \
    return Op(matched_inst)                                          \
        .WithOpcode(HloOpcode::k##NAME)                              \
        .WithOperand(0, std::forward<Arg>(arg));                     \
  }
XLA_UNOP_PATTERN(Abs)
XLA_UNOP_PATTERN(RoundNearestAfz)
XLA_UNOP_PATTERN(Bitcast)
XLA_UNOP_PATTERN(Broadcast)
XLA_UNOP_PATTERN(Ceil)
XLA_UNOP_PATTERN(Convert)
XLA_UNOP_PATTERN(Copy)
XLA_UNOP_PATTERN(Cos)
XLA_UNOP_PATTERN(AllReduce)
XLA_UNOP_PATTERN(Exp)
XLA_UNOP_PATTERN(Fft)
XLA_UNOP_PATTERN(Floor)
XLA_UNOP_PATTERN(GetTupleElement)
XLA_UNOP_PATTERN(Imag)
XLA_UNOP_PATTERN(Infeed)
XLA_UNOP_PATTERN(IsFinite)
XLA_UNOP_PATTERN(Log)
XLA_UNOP_PATTERN(Not)
XLA_UNOP_PATTERN(Negate)
XLA_UNOP_PATTERN(Real)
XLA_UNOP_PATTERN(Recv)
XLA_UNOP_PATTERN(RecvDone)
XLA_UNOP_PATTERN(ReducePrecision)
XLA_UNOP_PATTERN(Reshape)
XLA_UNOP_PATTERN(Reverse)
XLA_UNOP_PATTERN(Rsqrt)
XLA_UNOP_PATTERN(SendDone)
XLA_UNOP_PATTERN(Sign)
XLA_UNOP_PATTERN(Sin)
XLA_UNOP_PATTERN(Slice)
XLA_UNOP_PATTERN(Sqrt)
XLA_UNOP_PATTERN(Tanh)
XLA_UNOP_PATTERN(Transpose)
#undef XLA_UNOP_PATTERN
2108
// Helpers for binary instructions.
//
// XLA_BINOP_PATTERN(NAME) defines three overloads of NAME, each matching an
// instruction whose opcode is HloOpcode::k##NAME:
//   NAME()                        -- any operands;
//   NAME(lhs, rhs)                -- operand 0 matches `lhs` and operand 1
//                                    matches `rhs`, in that order;
//   NAME(matched_inst, lhs, rhs)  -- as NAME(lhs, rhs), and additionally
//                                    captures the matched instruction.
#define XLA_BINOP_PATTERN(NAME)                                             \
  inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) {       \
    return Op().WithOpcode(HloOpcode::k##NAME);                             \
  }                                                                         \
                                                                            \
  template <typename Lhs, typename Rhs>                                     \
  inline auto NAME(Lhs&& lhs, Rhs&& rhs)                                    \
      ->decltype(Op().WithOpcode(HloOpcode::k##NAME)                        \
                     .WithOperand(0, std::forward<Lhs>(lhs))                \
                     .WithOperand(1, std::forward<Rhs>(rhs))) {             \
    return Op()                                                             \
        .WithOpcode(HloOpcode::k##NAME)                                     \
        .WithOperand(0, std::forward<Lhs>(lhs))                             \
        .WithOperand(1, std::forward<Rhs>(rhs));                            \
  }                                                                         \
                                                                            \
  template <typename HloInstructionType, typename Lhs, typename Rhs>        \
  inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \
      ->decltype(Op(matched_inst)                                           \
                     .WithOpcode(HloOpcode::k##NAME)                        \
                     .WithOperand(0, std::forward<Lhs>(lhs))                \
                     .WithOperand(1, std::forward<Rhs>(rhs))) {             \
    return Op(matched_inst)                                                 \
        .WithOpcode(HloOpcode::k##NAME)                                     \
        .WithOperand(0, std::forward<Lhs>(lhs))                             \
        .WithOperand(1, std::forward<Rhs>(rhs));                            \
  }

// XLA_COMMUTATIVE_BINOP_PATTERN(NAME) defines everything XLA_BINOP_PATTERN
// does, plus NAME##AnyOrder(matched_inst, lhs, rhs) and
// NAME##AnyOrder(lhs, rhs), which accept the two operand patterns in either
// order (via WithBinaryOperandsAnyOrder). The non-capturing form forwards to
// the capturing one with a null `const HloInstruction**`.
#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME)                                \
  XLA_BINOP_PATTERN(NAME)                                                  \
                                                                           \
  template <typename HloInstructionType, typename Lhs, typename Rhs>       \
  inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
                             Rhs&& rhs)                                    \
      ->decltype(Op(matched_inst)                                          \
                     .WithOpcode(HloOpcode::k##NAME)                       \
                     .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),   \
                                                 std::forward<Rhs>(rhs))) { \
    return Op(matched_inst)                                                \
        .WithOpcode(HloOpcode::k##NAME)                                    \
        .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs),                \
                                    std::forward<Rhs>(rhs));               \
  }                                                                        \
  template <typename Lhs, typename Rhs>                                    \
  inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs)                         \
      ->decltype(NAME##AnyOrder<const HloInstruction>(                     \
          nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs))) {      \
    return NAME##AnyOrder<const HloInstruction>(                           \
        nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs));          \
  }
XLA_COMMUTATIVE_BINOP_PATTERN(Add)
XLA_BINOP_PATTERN(Atan2)
XLA_BINOP_PATTERN(Divide)
XLA_BINOP_PATTERN(Complex)
XLA_BINOP_PATTERN(Compare)
XLA_BINOP_PATTERN(Convolution)
XLA_BINOP_PATTERN(Dot)
XLA_BINOP_PATTERN(Gather)
XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
XLA_BINOP_PATTERN(Outfeed)
XLA_BINOP_PATTERN(Pad)
XLA_BINOP_PATTERN(Power)
XLA_BINOP_PATTERN(ReduceWindow)
XLA_BINOP_PATTERN(Remainder)
XLA_BINOP_PATTERN(Send)
XLA_BINOP_PATTERN(Subtract)
XLA_COMMUTATIVE_BINOP_PATTERN(And)
XLA_COMMUTATIVE_BINOP_PATTERN(Or)
XLA_BINOP_PATTERN(ShiftLeft)
XLA_BINOP_PATTERN(ShiftRightArithmetic)
XLA_BINOP_PATTERN(ShiftRightLogical)
#undef XLA_COMMUTATIVE_BINOP_PATTERN
#undef XLA_BINOP_PATTERN
2185
2186 // Helpers for ternary instructions.
2187 #define XLA_TERNOP_PATTERN(NAME) \
2188 inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
2189 return Op().WithOpcode(HloOpcode::k##NAME); \
2190 } \
2191 \
2192 template <typename Arg0, typename Arg1, typename Arg2> \
2193 inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) \
2194 ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \
2195 .WithOperand(0, std::forward<Arg0>(arg0)) \
2196 .WithOperand(1, std::forward<Arg1>(arg1)) \
2197 .WithOperand(2, std::forward<Arg2>(arg2))) { \
2198 return Op() \
2199 .WithOpcode(HloOpcode::k##NAME) \
2200 .WithOperand(0, std::forward<Arg0>(arg0)) \
2201 .WithOperand(1, std::forward<Arg1>(arg1)) \
2202 .WithOperand(2, std::forward<Arg2>(arg2)); \
2203 } \
2204 \
2205 template <typename HloInstructionType, typename Arg0, typename Arg1, \
2206 typename Arg2> \
2207 inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \
2208 Arg1&& arg1, Arg2&& arg2) \
2209 ->decltype(Op(matched_inst) \
2210 .WithOpcode(HloOpcode::k##NAME) \
2211 .WithOperand(0, std::forward<Arg0>(arg0)) \
2212 .WithOperand(1, std::forward<Arg1>(arg1)) \
2213 .WithOperand(2, std::forward<Arg2>(arg2))) { \
2214 return Op(matched_inst) \
2215 .WithOpcode(HloOpcode::k##NAME) \
2216 .WithOperand(0, std::forward<Arg0>(arg0)) \
2217 .WithOperand(1, std::forward<Arg1>(arg1)) \
2218 .WithOperand(2, std::forward<Arg2>(arg2)); \
2219 }
2220 XLA_TERNOP_PATTERN(Clamp);
2221 XLA_TERNOP_PATTERN(Scatter);
2222 XLA_TERNOP_PATTERN(Select);
2223 #undef XLA_TERNOP_PATTERN
2224
namespace detail {
// Base case: constrains operand `operand_num` of matcher `m` with the pattern
// `first_arg` and returns the resulting pattern.
template <typename Matcher, typename FirstArg>
inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg)
    -> decltype(m.WithOperand(operand_num, std::forward<FirstArg>(first_arg))) {
  return m.WithOperand(operand_num, std::forward<FirstArg>(first_arg));
}

// Recursive case: constrains operand `operand_num` with `first_arg`, then
// operands operand_num + 1, operand_num + 2, ... with the remaining `args`,
// one pattern per operand, in order.
//
// NOTE(review): this overload is not yet in scope inside its own trailing
// return type. When more than one pattern remains after `first_arg`, the
// recursive call in the decltype therefore appears to resolve only through
// argument-dependent lookup on the matcher type, which lives in this
// namespace. Keep WithOperands and the pattern types in `detail` together.
template <typename Matcher, typename FirstArg, typename... Args>
inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg,
                         Args&&... args)
    -> decltype(WithOperands(m.WithOperand(operand_num,
                                           std::forward<FirstArg>(first_arg)),
                             operand_num + 1, std::forward<Args>(args)...)) {
  return WithOperands(
      m.WithOperand(operand_num, std::forward<FirstArg>(first_arg)),
      operand_num + 1, std::forward<Args>(args)...);
}
}  // namespace detail
2243
2244 #define XLA_VARIADIC_OP_PATTERN(NAME) \
2245 inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
2246 return Op().WithOpcode(HloOpcode::k##NAME); \
2247 } \
2248 \
2249 template <typename... Args> \
2250 inline auto NAME(Args&&... args) \
2251 ->decltype(detail::WithOperands(Op().WithOpcode(HloOpcode::k##NAME) \
2252 .WithNumOperands(sizeof...(Args)), \
2253 0, std::forward<Args>(args)...)) { \
2254 return detail::WithOperands( \
2255 Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \
2256 /*operand_num=*/0, std::forward<Args>(args)...); \
2257 } \
2258 \
2259 template <typename HloInstructionType, typename... Args> \
2260 inline auto NAME(HloInstructionType** matched_inst, Args&&... args) \
2261 ->decltype(detail::WithOperands(Op(matched_inst) \
2262 .WithOpcode(HloOpcode::k##NAME) \
2263 .WithNumOperands(sizeof...(Args)), \
2264 0, std::forward<Args>(args)...)) { \
2265 return detail::WithOperands(Op(matched_inst) \
2266 .WithOpcode(HloOpcode::k##NAME) \
2267 .WithNumOperands(sizeof...(Args)), \
2268 /*operand_num=*/0, \
2269 std::forward<Args>(args)...); \
2270 }
2271
2272 // We could implement all ops as "variadic" ops, but it would make the
2273 // already-bad compile errors even worse.
2274 XLA_VARIADIC_OP_PATTERN(AfterAll);
2275 XLA_VARIADIC_OP_PATTERN(Concatenate);
2276 XLA_VARIADIC_OP_PATTERN(CustomCall);
2277 XLA_VARIADIC_OP_PATTERN(DynamicSlice)
2278 XLA_VARIADIC_OP_PATTERN(Map)
2279 XLA_VARIADIC_OP_PATTERN(Reduce);
2280 XLA_VARIADIC_OP_PATTERN(Sort);
2281 XLA_VARIADIC_OP_PATTERN(Tuple);
2282
2283 // Helpers for comparison instructions.
2284 #define XLA_COMPARE_PATTERN(NAME) \
2285 inline auto NAME()->decltype( \
2286 Op().WithOpcode(HloOpcode::kCompare) \
2287 .WithComparisonDirection(ComparisonDirection::k##NAME)) { \
2288 return Op() \
2289 .WithOpcode(HloOpcode::kCompare) \
2290 .WithComparisonDirection(ComparisonDirection::k##NAME); \
2291 } \
2292 \
2293 template <typename Lhs, typename Rhs> \
2294 inline auto NAME(Lhs&& lhs, Rhs&& rhs) \
2295 ->decltype(Op().WithOpcode(HloOpcode::kCompare) \
2296 .WithOperand(0, std::forward<Lhs>(lhs)) \
2297 .WithOperand(1, std::forward<Rhs>(rhs)) \
2298 .WithComparisonDirection(ComparisonDirection::k##NAME)) { \
2299 return Op() \
2300 .WithOpcode(HloOpcode::kCompare) \
2301 .WithOperand(0, std::forward<Lhs>(lhs)) \
2302 .WithOperand(1, std::forward<Rhs>(rhs)) \
2303 .WithComparisonDirection(ComparisonDirection::k##NAME); \
2304 } \
2305 \
2306 template <typename HloInstructionType, typename Lhs, typename Rhs> \
2307 inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \
2308 ->decltype(Op(matched_inst) \
2309 .WithOpcode(HloOpcode::kCompare) \
2310 .WithOperand(0, std::forward<Lhs>(lhs)) \
2311 .WithOperand(1, std::forward<Rhs>(rhs)) \
2312 .WithComparisonDirection(ComparisonDirection::k##NAME)) { \
2313 return Op(matched_inst) \
2314 .WithOpcode(HloOpcode::kCompare) \
2315 .WithOperand(0, std::forward<Lhs>(lhs)) \
2316 .WithOperand(1, std::forward<Rhs>(rhs)) \
2317 .WithComparisonDirection(ComparisonDirection::k##NAME); \
2318 }
2319
2320 #define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \
2321 XLA_COMPARE_PATTERN(NAME) \
2322 \
2323 template <typename HloInstructionType, typename Lhs, typename Rhs> \
2324 inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
2325 Rhs&& rhs) \
2326 ->decltype(Op(matched_inst) \
2327 .WithOpcode(HloOpcode::kCompare) \
2328 .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
2329 std::forward<Rhs>(rhs))) { \
2330 return Op(matched_inst) \
2331 .WithOpcode(HloOpcode::kCompare) \
2332 .WithBinaryOperandsAnyOrder(std::forward<Lhs>(lhs), \
2333 std::forward<Rhs>(rhs)); \
2334 } \
2335 template <typename Lhs, typename Rhs> \
2336 inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \
2337 ->decltype(NAME##AnyOrder<const HloInstruction>( \
2338 nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs))) { \
2339 return NAME##AnyOrder<const HloInstruction>( \
2340 nullptr, std::forward<Lhs>(lhs), std::forward<Rhs>(rhs)); \
2341 }
2342
2343 XLA_COMMUTATIVE_COMPARE_PATTERN(Eq);
2344 XLA_COMMUTATIVE_COMPARE_PATTERN(Ne);
2345 XLA_COMPARE_PATTERN(Ge);
2346 XLA_COMPARE_PATTERN(Gt);
2347 XLA_COMPARE_PATTERN(Le);
2348 XLA_COMPARE_PATTERN(Lt);
2349
2350 // Helpers for matching non-constant instructions.
2351 inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
2352 return Op().IsNonConstant();
2353 }
2354
2355 template <typename HloInstructionType>
2356 inline auto NonConstant(HloInstructionType** matched_inst)
2357 -> decltype(Op(matched_inst).IsNonConstant()) {
2358 return Op(matched_inst).IsNonConstant();
2359 }
2360
2361 // Add overloads for GetTupleElement which take a int64 specifying which tuple
2362 // element is selected.
2363 template <typename Arg>
2364 inline auto GetTupleElement(Arg&& arg, int64 tuple_index)
2365 -> decltype(Op().WithOpcode(HloOpcode::kGetTupleElement)
2366 .WithOperand(0, std::forward<Arg>(arg))
2367 .WithTupleIndex(tuple_index)) {
2368 return Op()
2369 .WithOpcode(HloOpcode::kGetTupleElement)
2370 .WithOperand(0, std::forward<Arg>(arg))
2371 .WithTupleIndex(tuple_index);
2372 }
2373
2374 template <typename HloInstructionType, typename Arg>
2375 inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
2376 int64 tuple_index)
2377 -> decltype(Op(matched_inst)
2378 .WithOpcode(HloOpcode::kGetTupleElement)
2379 .WithOperand(0, std::forward<Arg>(arg))
2380 .WithTupleIndex(tuple_index)) {
2381 return Op(matched_inst)
2382 .WithOpcode(HloOpcode::kGetTupleElement)
2383 .WithOperand(0, std::forward<Arg>(arg))
2384 .WithTupleIndex(tuple_index);
2385 }
2386
2387 // Add overloads for Parameter which take an int64 specifying the parameter
2388 // number.
2389 inline auto Parameter(int64 parameter_num) -> decltype(
2390 Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num)) {
2391 return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num);
2392 }
2393 template <typename HloInstructionType>
2394 inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num)
2395 -> decltype(Op(matched_inst)
2396 .WithOpcode(HloOpcode::kParameter)
2397 .WithParameterNum(parameter_num)) {
2398 return Op(matched_inst)
2399 .WithOpcode(HloOpcode::kParameter)
2400 .WithParameterNum(parameter_num);
2401 }
2402
2403 inline auto ConstantScalar() -> decltype(Op().IsConstantScalar()) {
2404 return Op().IsConstantScalar();
2405 }
2406
2407 template <typename HloInstructionType>
2408 inline auto ConstantScalar(HloInstructionType** matched_inst)
2409 -> decltype(Op(matched_inst).IsConstantScalar()) {
2410 return Op(matched_inst).IsConstantScalar();
2411 }
2412
2413 template <typename ScalarTy>
2414 inline auto ConstantScalar(ScalarTy val)
2415 -> decltype(Op().IsConstantScalar(val)) {
2416 return Op().IsConstantScalar(val);
2417 }
2418
2419 template <typename HloInstructionType, typename ScalarTy>
2420 inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val)
2421 -> decltype(Op(matched_inst).IsConstantScalar(val)) {
2422 return Op(matched_inst).IsConstantScalar(val);
2423 }
2424
2425 inline auto ConstantEffectiveScalar() -> decltype(Op().IsConstantScalar()) {
2426 return Op().IsConstantEffectiveScalar();
2427 }
2428
2429 template <typename HloInstructionType>
2430 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst)
2431 -> decltype(Op(matched_inst).IsConstantScalar()) {
2432 return Op(matched_inst).IsConstantEffectiveScalar();
2433 }
2434
2435 template <typename ScalarTy>
2436 inline auto ConstantEffectiveScalar(ScalarTy val)
2437 -> decltype(Op().IsConstantEffectiveScalar(val)) {
2438 return Op().IsConstantEffectiveScalar(val);
2439 }
2440
2441 template <typename HloInstructionType, typename ScalarTy>
2442 inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst,
2443 ScalarTy val)
2444 -> decltype(Op(matched_inst).IsConstantEffectiveScalar(val)) {
2445 return Op(matched_inst).IsConstantEffectiveScalar(val);
2446 }
2447
2448 } // namespace match
2449
2450 } // namespace xla
2451
2452 #undef EXPLAIN
2453 #pragma pop_macro("EXPLAIN")
2454 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
2455