1 //===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 //  Provides MatchVerifier, a base class to implement gtest matchers that
11 //  verify things that can be matched on the AST.
12 //
13 //  Also implements matchers based on MatchVerifier:
14 //  LocationVerifier and RangeVerifier to verify whether a matched node has
15 //  the expected source location or source range.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #ifndef LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
20 #define LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
21 
22 #include "clang/AST/ASTContext.h"
23 #include "clang/ASTMatchers/ASTMatchFinder.h"
24 #include "clang/ASTMatchers/ASTMatchers.h"
25 #include "clang/Tooling/Tooling.h"
26 #include "gtest/gtest.h"
27 
28 namespace clang {
29 namespace ast_matchers {
30 
/// \brief Source languages a test snippet can be parsed as.
///
/// MatchVerifier::match() maps each value to an input file name and, for the
/// C and C++ variants, appends a -std= flag to the tool arguments.
enum Language {
  Lang_C,      ///< "-std=c99", parsed as input.c
  Lang_C89,    ///< "-std=c89", parsed as input.c
  Lang_CXX,    ///< "-std=c++98", parsed as input.cc
  Lang_CXX11,  ///< "-std=c++11", parsed as input.cc
  Lang_OpenCL, ///< parsed as input.cl; no -std= flag is added
  Lang_OBJCXX  ///< parsed as input.mm; no -std= flag is added
};
39 
40 /// \brief Base class for verifying some property of nodes found by a matcher.
41 template <typename NodeType>
42 class MatchVerifier : public MatchFinder::MatchCallback {
43 public:
44   template <typename MatcherType>
match(const std::string & Code,const MatcherType & AMatcher)45   testing::AssertionResult match(const std::string &Code,
46                                  const MatcherType &AMatcher) {
47     std::vector<std::string> Args;
48     return match(Code, AMatcher, Args, Lang_CXX);
49   }
50 
51   template <typename MatcherType>
match(const std::string & Code,const MatcherType & AMatcher,Language L)52   testing::AssertionResult match(const std::string &Code,
53                                  const MatcherType &AMatcher,
54                                  Language L) {
55     std::vector<std::string> Args;
56     return match(Code, AMatcher, Args, L);
57   }
58 
59   template <typename MatcherType>
60   testing::AssertionResult match(const std::string &Code,
61                                  const MatcherType &AMatcher,
62                                  std::vector<std::string>& Args,
63                                  Language L);
64 
65 protected:
66   void run(const MatchFinder::MatchResult &Result) override;
verify(const MatchFinder::MatchResult & Result,const NodeType & Node)67   virtual void verify(const MatchFinder::MatchResult &Result,
68                       const NodeType &Node) {}
69 
setFailure(const Twine & Result)70   void setFailure(const Twine &Result) {
71     Verified = false;
72     VerifyResult = Result.str();
73   }
74 
setSuccess()75   void setSuccess() {
76     Verified = true;
77   }
78 
79 private:
80   bool Verified;
81   std::string VerifyResult;
82 };
83 
84 /// \brief Runs a matcher over some code, and returns the result of the
85 /// verifier for the matched node.
86 template <typename NodeType> template <typename MatcherType>
match(const std::string & Code,const MatcherType & AMatcher,std::vector<std::string> & Args,Language L)87 testing::AssertionResult MatchVerifier<NodeType>::match(
88     const std::string &Code, const MatcherType &AMatcher,
89     std::vector<std::string>& Args, Language L) {
90   MatchFinder Finder;
91   Finder.addMatcher(AMatcher.bind(""), this);
92   std::unique_ptr<tooling::FrontendActionFactory> Factory(
93       tooling::newFrontendActionFactory(&Finder));
94 
95   StringRef FileName;
96   switch (L) {
97   case Lang_C:
98     Args.push_back("-std=c99");
99     FileName = "input.c";
100     break;
101   case Lang_C89:
102     Args.push_back("-std=c89");
103     FileName = "input.c";
104     break;
105   case Lang_CXX:
106     Args.push_back("-std=c++98");
107     FileName = "input.cc";
108     break;
109   case Lang_CXX11:
110     Args.push_back("-std=c++11");
111     FileName = "input.cc";
112     break;
113   case Lang_OpenCL:
114     FileName = "input.cl";
115     break;
116   case Lang_OBJCXX:
117     FileName = "input.mm";
118     break;
119   }
120 
121   // Default to failure in case callback is never called
122   setFailure("Could not find match");
123   if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
124     return testing::AssertionFailure() << "Parsing error";
125   if (!Verified)
126     return testing::AssertionFailure() << VerifyResult;
127   return testing::AssertionSuccess();
128 }
129 
130 template <typename NodeType>
run(const MatchFinder::MatchResult & Result)131 void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
132   const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
133   if (!Node) {
134     setFailure("Matched node has wrong type");
135   } else {
136     // Callback has been called, default to success.
137     setSuccess();
138     verify(Result, *Node);
139   }
140 }
141 
142 template <>
run(const MatchFinder::MatchResult & Result)143 inline void MatchVerifier<ast_type_traits::DynTypedNode>::run(
144     const MatchFinder::MatchResult &Result) {
145   BoundNodes::IDToNodeMap M = Result.Nodes.getMap();
146   BoundNodes::IDToNodeMap::const_iterator I = M.find("");
147   if (I == M.end()) {
148     setFailure("Node was not bound");
149   } else {
150     // Callback has been called, default to success.
151     setSuccess();
152     verify(Result, I->second);
153   }
154 }
155 
156 /// \brief Verify whether a node has the correct source location.
157 ///
158 /// By default, Node.getSourceLocation() is checked. This can be changed
159 /// by overriding getLocation().
160 template <typename NodeType>
161 class LocationVerifier : public MatchVerifier<NodeType> {
162 public:
expectLocation(unsigned Line,unsigned Column)163   void expectLocation(unsigned Line, unsigned Column) {
164     ExpectLine = Line;
165     ExpectColumn = Column;
166   }
167 
168 protected:
verify(const MatchFinder::MatchResult & Result,const NodeType & Node)169   void verify(const MatchFinder::MatchResult &Result,
170               const NodeType &Node) override {
171     SourceLocation Loc = getLocation(Node);
172     unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
173     unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
174     if (Line != ExpectLine || Column != ExpectColumn) {
175       std::string MsgStr;
176       llvm::raw_string_ostream Msg(MsgStr);
177       Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
178           << ">, found <";
179       Loc.print(Msg, *Result.SourceManager);
180       Msg << '>';
181       this->setFailure(Msg.str());
182     }
183   }
184 
getLocation(const NodeType & Node)185   virtual SourceLocation getLocation(const NodeType &Node) {
186     return Node.getLocation();
187   }
188 
189 private:
190   unsigned ExpectLine, ExpectColumn;
191 };
192 
193 /// \brief Verify whether a node has the correct source range.
194 ///
195 /// By default, Node.getSourceRange() is checked. This can be changed
196 /// by overriding getRange().
197 template <typename NodeType>
198 class RangeVerifier : public MatchVerifier<NodeType> {
199 public:
expectRange(unsigned BeginLine,unsigned BeginColumn,unsigned EndLine,unsigned EndColumn)200   void expectRange(unsigned BeginLine, unsigned BeginColumn,
201                    unsigned EndLine, unsigned EndColumn) {
202     ExpectBeginLine = BeginLine;
203     ExpectBeginColumn = BeginColumn;
204     ExpectEndLine = EndLine;
205     ExpectEndColumn = EndColumn;
206   }
207 
208 protected:
verify(const MatchFinder::MatchResult & Result,const NodeType & Node)209   void verify(const MatchFinder::MatchResult &Result,
210               const NodeType &Node) override {
211     SourceRange R = getRange(Node);
212     SourceLocation Begin = R.getBegin();
213     SourceLocation End = R.getEnd();
214     unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin);
215     unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin);
216     unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End);
217     unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End);
218     if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
219         EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
220       std::string MsgStr;
221       llvm::raw_string_ostream Msg(MsgStr);
222       Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
223           << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
224       Begin.print(Msg, *Result.SourceManager);
225       Msg << '-';
226       End.print(Msg, *Result.SourceManager);
227       Msg << '>';
228       this->setFailure(Msg.str());
229     }
230   }
231 
getRange(const NodeType & Node)232   virtual SourceRange getRange(const NodeType &Node) {
233     return Node.getSourceRange();
234   }
235 
236 private:
237   unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
238 };
239 
240 /// \brief Verify whether a node's dump contains a given substring.
241 class DumpVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
242 public:
expectSubstring(const std::string & Str)243   void expectSubstring(const std::string &Str) {
244     ExpectSubstring = Str;
245   }
246 
247 protected:
verify(const MatchFinder::MatchResult & Result,const ast_type_traits::DynTypedNode & Node)248   void verify(const MatchFinder::MatchResult &Result,
249               const ast_type_traits::DynTypedNode &Node) override {
250     std::string DumpStr;
251     llvm::raw_string_ostream Dump(DumpStr);
252     Node.dump(Dump, *Result.SourceManager);
253 
254     if (Dump.str().find(ExpectSubstring) == std::string::npos) {
255       std::string MsgStr;
256       llvm::raw_string_ostream Msg(MsgStr);
257       Msg << "Expected dump substring <" << ExpectSubstring << ">, found <"
258           << Dump.str() << '>';
259       this->setFailure(Msg.str());
260     }
261   }
262 
263 private:
264   std::string ExpectSubstring;
265 };
266 
267 /// \brief Verify whether a node's pretty print matches a given string.
268 class PrintVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
269 public:
expectString(const std::string & Str)270   void expectString(const std::string &Str) {
271     ExpectString = Str;
272   }
273 
274 protected:
verify(const MatchFinder::MatchResult & Result,const ast_type_traits::DynTypedNode & Node)275   void verify(const MatchFinder::MatchResult &Result,
276               const ast_type_traits::DynTypedNode &Node) override {
277     std::string PrintStr;
278     llvm::raw_string_ostream Print(PrintStr);
279     Node.print(Print, Result.Context->getPrintingPolicy());
280 
281     if (Print.str() != ExpectString) {
282       std::string MsgStr;
283       llvm::raw_string_ostream Msg(MsgStr);
284       Msg << "Expected pretty print <" << ExpectString << ">, found <"
285           << Print.str() << '>';
286       this->setFailure(Msg.str());
287     }
288   }
289 
290 private:
291   std::string ExpectString;
292 };
293 
294 } // end namespace ast_matchers
295 } // end namespace clang
296 
297 #endif
298