1 //===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 //  Provides MatchVerifier, a base class to implement gtest matchers that
11 //  verify things that can be matched on the AST.
12 //
13 //  Also implements matchers based on MatchVerifier:
14 //  LocationVerifier and RangeVerifier to verify whether a matched node has
15 //  the expected source location or source range.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #ifndef LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
20 #define LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
21 
22 #include "clang/AST/ASTContext.h"
23 #include "clang/ASTMatchers/ASTMatchFinder.h"
24 #include "clang/ASTMatchers/ASTMatchers.h"
25 #include "clang/Tooling/Tooling.h"
26 #include "gtest/gtest.h"
27 
28 namespace clang {
29 namespace ast_matchers {
30 
/// \brief Input language selector for MatchVerifier::match().
///
/// The enumerator picks the virtual input file name and, for the C and C++
/// variants, the -std= flag appended to the tool arguments.
enum Language {
  Lang_C,      ///< "input.c" with -std=c99.
  Lang_C89,    ///< "input.c" with -std=c89.
  Lang_CXX,    ///< "input.cc" with -std=c++98.
  Lang_CXX11,  ///< "input.cc" with -std=c++11.
  Lang_OpenCL, ///< "input.cl"; no -std flag is added.
  Lang_OBJCXX  ///< "input.mm"; no -std flag is added.
};
39 
40 /// \brief Base class for verifying some property of nodes found by a matcher.
41 template <typename NodeType>
42 class MatchVerifier : public MatchFinder::MatchCallback {
43 public:
44   template <typename MatcherType>
match(const std::string & Code,const MatcherType & AMatcher)45   testing::AssertionResult match(const std::string &Code,
46                                  const MatcherType &AMatcher) {
47     std::vector<std::string> Args;
48     return match(Code, AMatcher, Args, Lang_CXX);
49   }
50 
51   template <typename MatcherType>
match(const std::string & Code,const MatcherType & AMatcher,Language L)52   testing::AssertionResult match(const std::string &Code,
53                                  const MatcherType &AMatcher,
54                                  Language L) {
55     std::vector<std::string> Args;
56     return match(Code, AMatcher, Args, L);
57   }
58 
59   template <typename MatcherType>
60   testing::AssertionResult match(const std::string &Code,
61                                  const MatcherType &AMatcher,
62                                  std::vector<std::string>& Args,
63                                  Language L);
64 
65   template <typename MatcherType>
66   testing::AssertionResult match(const Decl *D, const MatcherType &AMatcher);
67 
68 protected:
69   void run(const MatchFinder::MatchResult &Result) override;
verify(const MatchFinder::MatchResult & Result,const NodeType & Node)70   virtual void verify(const MatchFinder::MatchResult &Result,
71                       const NodeType &Node) {}
72 
setFailure(const Twine & Result)73   void setFailure(const Twine &Result) {
74     Verified = false;
75     VerifyResult = Result.str();
76   }
77 
setSuccess()78   void setSuccess() {
79     Verified = true;
80   }
81 
82 private:
83   bool Verified;
84   std::string VerifyResult;
85 };
86 
87 /// \brief Runs a matcher over some code, and returns the result of the
88 /// verifier for the matched node.
89 template <typename NodeType> template <typename MatcherType>
match(const std::string & Code,const MatcherType & AMatcher,std::vector<std::string> & Args,Language L)90 testing::AssertionResult MatchVerifier<NodeType>::match(
91     const std::string &Code, const MatcherType &AMatcher,
92     std::vector<std::string>& Args, Language L) {
93   MatchFinder Finder;
94   Finder.addMatcher(AMatcher.bind(""), this);
95   std::unique_ptr<tooling::FrontendActionFactory> Factory(
96       tooling::newFrontendActionFactory(&Finder));
97 
98   StringRef FileName;
99   switch (L) {
100   case Lang_C:
101     Args.push_back("-std=c99");
102     FileName = "input.c";
103     break;
104   case Lang_C89:
105     Args.push_back("-std=c89");
106     FileName = "input.c";
107     break;
108   case Lang_CXX:
109     Args.push_back("-std=c++98");
110     FileName = "input.cc";
111     break;
112   case Lang_CXX11:
113     Args.push_back("-std=c++11");
114     FileName = "input.cc";
115     break;
116   case Lang_OpenCL:
117     FileName = "input.cl";
118     break;
119   case Lang_OBJCXX:
120     FileName = "input.mm";
121     break;
122   }
123 
124   // Default to failure in case callback is never called
125   setFailure("Could not find match");
126   if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
127     return testing::AssertionFailure() << "Parsing error";
128   if (!Verified)
129     return testing::AssertionFailure() << VerifyResult;
130   return testing::AssertionSuccess();
131 }
132 
133 /// \brief Runs a matcher over some AST, and returns the result of the
134 /// verifier for the matched node.
135 template <typename NodeType> template <typename MatcherType>
match(const Decl * D,const MatcherType & AMatcher)136 testing::AssertionResult MatchVerifier<NodeType>::match(
137     const Decl *D, const MatcherType &AMatcher) {
138   MatchFinder Finder;
139   Finder.addMatcher(AMatcher.bind(""), this);
140 
141   setFailure("Could not find match");
142   Finder.match(*D, D->getASTContext());
143 
144   if (!Verified)
145     return testing::AssertionFailure() << VerifyResult;
146   return testing::AssertionSuccess();
147 }
148 
149 template <typename NodeType>
run(const MatchFinder::MatchResult & Result)150 void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
151   const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
152   if (!Node) {
153     setFailure("Matched node has wrong type");
154   } else {
155     // Callback has been called, default to success.
156     setSuccess();
157     verify(Result, *Node);
158   }
159 }
160 
161 template <>
run(const MatchFinder::MatchResult & Result)162 inline void MatchVerifier<ast_type_traits::DynTypedNode>::run(
163     const MatchFinder::MatchResult &Result) {
164   BoundNodes::IDToNodeMap M = Result.Nodes.getMap();
165   BoundNodes::IDToNodeMap::const_iterator I = M.find("");
166   if (I == M.end()) {
167     setFailure("Node was not bound");
168   } else {
169     // Callback has been called, default to success.
170     setSuccess();
171     verify(Result, I->second);
172   }
173 }
174 
175 /// \brief Verify whether a node has the correct source location.
176 ///
177 /// By default, Node.getSourceLocation() is checked. This can be changed
178 /// by overriding getLocation().
179 template <typename NodeType>
180 class LocationVerifier : public MatchVerifier<NodeType> {
181 public:
expectLocation(unsigned Line,unsigned Column)182   void expectLocation(unsigned Line, unsigned Column) {
183     ExpectLine = Line;
184     ExpectColumn = Column;
185   }
186 
187 protected:
verify(const MatchFinder::MatchResult & Result,const NodeType & Node)188   void verify(const MatchFinder::MatchResult &Result,
189               const NodeType &Node) override {
190     SourceLocation Loc = getLocation(Node);
191     unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
192     unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
193     if (Line != ExpectLine || Column != ExpectColumn) {
194       std::string MsgStr;
195       llvm::raw_string_ostream Msg(MsgStr);
196       Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
197           << ">, found <";
198       Loc.print(Msg, *Result.SourceManager);
199       Msg << '>';
200       this->setFailure(Msg.str());
201     }
202   }
203 
getLocation(const NodeType & Node)204   virtual SourceLocation getLocation(const NodeType &Node) {
205     return Node.getLocation();
206   }
207 
208 private:
209   unsigned ExpectLine, ExpectColumn;
210 };
211 
212 /// \brief Verify whether a node has the correct source range.
213 ///
214 /// By default, Node.getSourceRange() is checked. This can be changed
215 /// by overriding getRange().
216 template <typename NodeType>
217 class RangeVerifier : public MatchVerifier<NodeType> {
218 public:
expectRange(unsigned BeginLine,unsigned BeginColumn,unsigned EndLine,unsigned EndColumn)219   void expectRange(unsigned BeginLine, unsigned BeginColumn,
220                    unsigned EndLine, unsigned EndColumn) {
221     ExpectBeginLine = BeginLine;
222     ExpectBeginColumn = BeginColumn;
223     ExpectEndLine = EndLine;
224     ExpectEndColumn = EndColumn;
225   }
226 
227 protected:
verify(const MatchFinder::MatchResult & Result,const NodeType & Node)228   void verify(const MatchFinder::MatchResult &Result,
229               const NodeType &Node) override {
230     SourceRange R = getRange(Node);
231     SourceLocation Begin = R.getBegin();
232     SourceLocation End = R.getEnd();
233     unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin);
234     unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin);
235     unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End);
236     unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End);
237     if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
238         EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
239       std::string MsgStr;
240       llvm::raw_string_ostream Msg(MsgStr);
241       Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
242           << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
243       Begin.print(Msg, *Result.SourceManager);
244       Msg << '-';
245       End.print(Msg, *Result.SourceManager);
246       Msg << '>';
247       this->setFailure(Msg.str());
248     }
249   }
250 
getRange(const NodeType & Node)251   virtual SourceRange getRange(const NodeType &Node) {
252     return Node.getSourceRange();
253   }
254 
255 private:
256   unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
257 };
258 
259 /// \brief Verify whether a node's dump contains a given substring.
260 class DumpVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
261 public:
expectSubstring(const std::string & Str)262   void expectSubstring(const std::string &Str) {
263     ExpectSubstring = Str;
264   }
265 
266 protected:
verify(const MatchFinder::MatchResult & Result,const ast_type_traits::DynTypedNode & Node)267   void verify(const MatchFinder::MatchResult &Result,
268               const ast_type_traits::DynTypedNode &Node) override {
269     std::string DumpStr;
270     llvm::raw_string_ostream Dump(DumpStr);
271     Node.dump(Dump, *Result.SourceManager);
272 
273     if (Dump.str().find(ExpectSubstring) == std::string::npos) {
274       std::string MsgStr;
275       llvm::raw_string_ostream Msg(MsgStr);
276       Msg << "Expected dump substring <" << ExpectSubstring << ">, found <"
277           << Dump.str() << '>';
278       this->setFailure(Msg.str());
279     }
280   }
281 
282 private:
283   std::string ExpectSubstring;
284 };
285 
286 /// \brief Verify whether a node's pretty print matches a given string.
287 class PrintVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
288 public:
expectString(const std::string & Str)289   void expectString(const std::string &Str) {
290     ExpectString = Str;
291   }
292 
293 protected:
verify(const MatchFinder::MatchResult & Result,const ast_type_traits::DynTypedNode & Node)294   void verify(const MatchFinder::MatchResult &Result,
295               const ast_type_traits::DynTypedNode &Node) override {
296     std::string PrintStr;
297     llvm::raw_string_ostream Print(PrintStr);
298     Node.print(Print, Result.Context->getPrintingPolicy());
299 
300     if (Print.str() != ExpectString) {
301       std::string MsgStr;
302       llvm::raw_string_ostream Msg(MsgStr);
303       Msg << "Expected pretty print <" << ExpectString << ">, found <"
304           << Print.str() << '>';
305       this->setFailure(Msg.str());
306     }
307   }
308 
309 private:
310   std::string ExpectString;
311 };
312 
313 } // end namespace ast_matchers
314 } // end namespace clang
315 
316 #endif
317