1 //===- unittest/AST/DeclMatcher.h - AST unit test support ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_CLANG_UNITTESTS_AST_DECLMATCHER_H
10 #define LLVM_CLANG_UNITTESTS_AST_DECLMATCHER_H
11 
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include <cassert>
#include <functional>
#include <utility>
13 
14 namespace clang {
15 namespace ast_matchers {
16 
// Selects which match a DeclMatcher keeps when the matcher fires more than
// once: the first node reported, or the last one.
enum class DeclMatcherKind {
  First = 0,
  Last = 1,
};
18 
19 // Matcher class to retrieve the first/last matched node under a given AST.
20 template <typename NodeType, DeclMatcherKind MatcherKind>
21 class DeclMatcher : public MatchFinder::MatchCallback {
22   NodeType *Node = nullptr;
run(const MatchFinder::MatchResult & Result)23   void run(const MatchFinder::MatchResult &Result) override {
24     if ((MatcherKind == DeclMatcherKind::First && Node == nullptr) ||
25         MatcherKind == DeclMatcherKind::Last) {
26       Node = const_cast<NodeType *>(Result.Nodes.getNodeAs<NodeType>(""));
27     }
28   }
29 public:
30   // Returns the first/last matched node under the tree rooted in `D`.
31   template <typename MatcherType>
match(const Decl * D,const MatcherType & AMatcher)32   NodeType *match(const Decl *D, const MatcherType &AMatcher) {
33     MatchFinder Finder;
34     Finder.addMatcher(AMatcher.bind(""), this);
35     Finder.matchAST(D->getASTContext());
36     assert(Node);
37     return Node;
38   }
39 };
40 template <typename NodeType>
41 using LastDeclMatcher = DeclMatcher<NodeType, DeclMatcherKind::Last>;
42 template <typename NodeType>
43 using FirstDeclMatcher = DeclMatcher<NodeType, DeclMatcherKind::First>;
44 
45 template <typename NodeType>
46 class DeclCounterWithPredicate : public MatchFinder::MatchCallback {
47   using UnaryPredicate = std::function<bool(const NodeType *)>;
48   UnaryPredicate Predicate;
49   unsigned Count = 0;
run(const MatchFinder::MatchResult & Result)50   void run(const MatchFinder::MatchResult &Result) override {
51     if (auto N = Result.Nodes.getNodeAs<NodeType>("")) {
52       if (Predicate(N))
53         ++Count;
54     }
55   }
56 
57 public:
DeclCounterWithPredicate()58   DeclCounterWithPredicate()
59       : Predicate([](const NodeType *) { return true; }) {}
DeclCounterWithPredicate(UnaryPredicate P)60   DeclCounterWithPredicate(UnaryPredicate P) : Predicate(P) {}
61   // Returns the number of matched nodes which satisfy the predicate under the
62   // tree rooted in `D`.
63   template <typename MatcherType>
match(const Decl * D,const MatcherType & AMatcher)64   unsigned match(const Decl *D, const MatcherType &AMatcher) {
65     MatchFinder Finder;
66     Finder.addMatcher(AMatcher.bind(""), this);
67     Finder.matchAST(D->getASTContext());
68     return Count;
69   }
70 };
71 
72 template <typename NodeType>
73 using DeclCounter = DeclCounterWithPredicate<NodeType>;
74 
75 } // end namespace ast_matchers
76 } // end namespace clang
77 
78 #endif
79