1 //===--- TestVisitor.h ------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// \brief Defines utility templates for RecursiveASTVisitor related tests.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_CLANG_UNITTESTS_TOOLING_TESTVISITOR_H
15 #define LLVM_CLANG_UNITTESTS_TOOLING_TESTVISITOR_H
16 
17 #include "clang/AST/ASTConsumer.h"
18 #include "clang/AST/ASTContext.h"
19 #include "clang/AST/RecursiveASTVisitor.h"
20 #include "clang/Frontend/CompilerInstance.h"
21 #include "clang/Frontend/FrontendAction.h"
22 #include "clang/Tooling/Tooling.h"
23 #include "gtest/gtest.h"
24 #include <vector>
25 
26 namespace clang {
27 
28 /// \brief Base class for simple RecursiveASTVisitor based tests.
29 ///
30 /// This is a drop-in replacement for RecursiveASTVisitor itself, with the
31 /// additional capability of running it over a snippet of code.
32 ///
33 /// Visits template instantiations and implicit code by default.
34 template <typename T>
35 class TestVisitor : public RecursiveASTVisitor<T> {
36 public:
TestVisitor()37   TestVisitor() { }
38 
~TestVisitor()39   virtual ~TestVisitor() { }
40 
41   enum Language {
42     Lang_C,
43     Lang_CXX98,
44     Lang_CXX11,
45     Lang_CXX14,
46     Lang_CXX17,
47     Lang_CXX2a,
48     Lang_OBJC,
49     Lang_OBJCXX11,
50     Lang_CXX = Lang_CXX98
51   };
52 
53   /// \brief Runs the current AST visitor over the given code.
54   bool runOver(StringRef Code, Language L = Lang_CXX) {
55     std::vector<std::string> Args;
56     switch (L) {
57       case Lang_C:
58         Args.push_back("-x");
59         Args.push_back("c");
60         break;
61       case Lang_CXX98: Args.push_back("-std=c++98"); break;
62       case Lang_CXX11: Args.push_back("-std=c++11"); break;
63       case Lang_CXX14: Args.push_back("-std=c++14"); break;
64       case Lang_CXX17: Args.push_back("-std=c++17"); break;
65       case Lang_CXX2a: Args.push_back("-std=c++2a"); break;
66       case Lang_OBJC:
67         Args.push_back("-ObjC");
68         Args.push_back("-fobjc-runtime=macosx-10.12.0");
69         break;
70       case Lang_OBJCXX11:
71         Args.push_back("-ObjC++");
72         Args.push_back("-std=c++11");
73         Args.push_back("-fblocks");
74         break;
75     }
76     return tooling::runToolOnCodeWithArgs(CreateTestAction(), Code, Args);
77   }
78 
shouldVisitTemplateInstantiations()79   bool shouldVisitTemplateInstantiations() const {
80     return true;
81   }
82 
shouldVisitImplicitCode()83   bool shouldVisitImplicitCode() const {
84     return true;
85   }
86 
87 protected:
CreateTestAction()88   virtual std::unique_ptr<ASTFrontendAction> CreateTestAction() {
89     return std::make_unique<TestAction>(this);
90   }
91 
92   class FindConsumer : public ASTConsumer {
93   public:
FindConsumer(TestVisitor * Visitor)94     FindConsumer(TestVisitor *Visitor) : Visitor(Visitor) {}
95 
HandleTranslationUnit(clang::ASTContext & Context)96     void HandleTranslationUnit(clang::ASTContext &Context) override {
97       Visitor->Context = &Context;
98       Visitor->TraverseDecl(Context.getTranslationUnitDecl());
99     }
100 
101   private:
102     TestVisitor *Visitor;
103   };
104 
105   class TestAction : public ASTFrontendAction {
106   public:
TestAction(TestVisitor * Visitor)107     TestAction(TestVisitor *Visitor) : Visitor(Visitor) {}
108 
109     std::unique_ptr<clang::ASTConsumer>
CreateASTConsumer(CompilerInstance &,llvm::StringRef dummy)110     CreateASTConsumer(CompilerInstance &, llvm::StringRef dummy) override {
111       /// TestConsumer will be deleted by the framework calling us.
112       return std::make_unique<FindConsumer>(Visitor);
113     }
114 
115   protected:
116     TestVisitor *Visitor;
117   };
118 
119   ASTContext *Context;
120 };
121 
122 /// \brief A RecursiveASTVisitor to check that certain matches are (or are
123 /// not) observed during visitation.
124 ///
125 /// This is a RecursiveASTVisitor for testing the RecursiveASTVisitor itself,
126 /// and allows simple creation of test visitors running matches on only a small
127 /// subset of the Visit* methods.
128 template <typename T, template <typename> class Visitor = TestVisitor>
129 class ExpectedLocationVisitor : public Visitor<T> {
130 public:
131   /// \brief Expect 'Match' *not* to occur at the given 'Line' and 'Column'.
132   ///
133   /// Any number of matches can be disallowed.
DisallowMatch(Twine Match,unsigned Line,unsigned Column)134   void DisallowMatch(Twine Match, unsigned Line, unsigned Column) {
135     DisallowedMatches.push_back(MatchCandidate(Match, Line, Column));
136   }
137 
138   /// \brief Expect 'Match' to occur at the given 'Line' and 'Column'.
139   ///
140   /// Any number of expected matches can be set by calling this repeatedly.
141   /// Each is expected to be matched 'Times' number of times. (This is useful in
142   /// cases in which different AST nodes can match at the same source code
143   /// location.)
144   void ExpectMatch(Twine Match, unsigned Line, unsigned Column,
145                    unsigned Times = 1) {
146     ExpectedMatches.push_back(ExpectedMatch(Match, Line, Column, Times));
147   }
148 
149   /// \brief Checks that all expected matches have been found.
~ExpectedLocationVisitor()150   ~ExpectedLocationVisitor() override {
151     for (typename std::vector<ExpectedMatch>::const_iterator
152              It = ExpectedMatches.begin(), End = ExpectedMatches.end();
153          It != End; ++It) {
154       It->ExpectFound();
155     }
156   }
157 
158 protected:
159   /// \brief Checks an actual match against expected and disallowed matches.
160   ///
161   /// Implementations are required to call this with appropriate values
162   /// for 'Name' during visitation.
Match(StringRef Name,SourceLocation Location)163   void Match(StringRef Name, SourceLocation Location) {
164     const FullSourceLoc FullLocation = this->Context->getFullLoc(Location);
165 
166     for (typename std::vector<MatchCandidate>::const_iterator
167              It = DisallowedMatches.begin(), End = DisallowedMatches.end();
168          It != End; ++It) {
169       EXPECT_FALSE(It->Matches(Name, FullLocation))
170           << "Matched disallowed " << *It;
171     }
172 
173     for (typename std::vector<ExpectedMatch>::iterator
174              It = ExpectedMatches.begin(), End = ExpectedMatches.end();
175          It != End; ++It) {
176       It->UpdateFor(Name, FullLocation, this->Context->getSourceManager());
177     }
178   }
179 
180  private:
181   struct MatchCandidate {
182     std::string ExpectedName;
183     unsigned LineNumber;
184     unsigned ColumnNumber;
185 
MatchCandidateMatchCandidate186     MatchCandidate(Twine Name, unsigned LineNumber, unsigned ColumnNumber)
187       : ExpectedName(Name.str()), LineNumber(LineNumber),
188         ColumnNumber(ColumnNumber) {
189     }
190 
MatchesMatchCandidate191     bool Matches(StringRef Name, FullSourceLoc const &Location) const {
192       return MatchesName(Name) && MatchesLocation(Location);
193     }
194 
PartiallyMatchesMatchCandidate195     bool PartiallyMatches(StringRef Name, FullSourceLoc const &Location) const {
196       return MatchesName(Name) || MatchesLocation(Location);
197     }
198 
MatchesNameMatchCandidate199     bool MatchesName(StringRef Name) const {
200       return Name == ExpectedName;
201     }
202 
MatchesLocationMatchCandidate203     bool MatchesLocation(FullSourceLoc const &Location) const {
204       return Location.isValid() &&
205           Location.getSpellingLineNumber() == LineNumber &&
206           Location.getSpellingColumnNumber() == ColumnNumber;
207     }
208 
209     friend std::ostream &operator<<(std::ostream &Stream,
210                                     MatchCandidate const &Match) {
211       return Stream << Match.ExpectedName
212                     << " at " << Match.LineNumber << ":" << Match.ColumnNumber;
213     }
214   };
215 
216   struct ExpectedMatch {
ExpectedMatchExpectedMatch217     ExpectedMatch(Twine Name, unsigned LineNumber, unsigned ColumnNumber,
218                   unsigned Times)
219         : Candidate(Name, LineNumber, ColumnNumber), TimesExpected(Times),
220           TimesSeen(0) {}
221 
UpdateForExpectedMatch222     void UpdateFor(StringRef Name, FullSourceLoc Location, SourceManager &SM) {
223       if (Candidate.Matches(Name, Location)) {
224         EXPECT_LT(TimesSeen, TimesExpected);
225         ++TimesSeen;
226       } else if (TimesSeen < TimesExpected &&
227                  Candidate.PartiallyMatches(Name, Location)) {
228         llvm::raw_string_ostream Stream(PartialMatches);
229         Stream << ", partial match: \"" << Name << "\" at ";
230         Location.print(Stream, SM);
231       }
232     }
233 
ExpectFoundExpectedMatch234     void ExpectFound() const {
235       EXPECT_EQ(TimesExpected, TimesSeen)
236           << "Expected \"" << Candidate.ExpectedName
237           << "\" at " << Candidate.LineNumber
238           << ":" << Candidate.ColumnNumber << PartialMatches;
239     }
240 
241     MatchCandidate Candidate;
242     std::string PartialMatches;
243     unsigned TimesExpected;
244     unsigned TimesSeen;
245   };
246 
247   std::vector<MatchCandidate> DisallowedMatches;
248   std::vector<ExpectedMatch> ExpectedMatches;
249 };
250 }
251 
252 #endif
253