1 //===- unittest/Tooling/RefactoringCallbacksTest.cpp ----------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "clang/Tooling/RefactoringCallbacks.h"
11 #include "RewriterTestContext.h"
12 #include "clang/ASTMatchers/ASTMatchFinder.h"
13 #include "clang/ASTMatchers/ASTMatchers.h"
14 #include "gtest/gtest.h"
15 
16 namespace clang {
17 namespace tooling {
18 
19 using namespace ast_matchers;
20 
21 template <typename T>
expectRewritten(const std::string & Code,const std::string & Expected,const T & AMatcher,RefactoringCallback & Callback)22 void expectRewritten(const std::string &Code,
23                      const std::string &Expected,
24                      const T &AMatcher,
25                      RefactoringCallback &Callback) {
26   MatchFinder Finder;
27   Finder.addMatcher(AMatcher, &Callback);
28   std::unique_ptr<tooling::FrontendActionFactory> Factory(
29       tooling::newFrontendActionFactory(&Finder));
30   ASSERT_TRUE(tooling::runToolOnCode(Factory->create(), Code))
31       << "Parsing error in \"" << Code << "\"";
32   RewriterTestContext Context;
33   FileID ID = Context.createInMemoryFile("input.cc", Code);
34   EXPECT_TRUE(tooling::applyAllReplacements(Callback.getReplacements(),
35                                             Context.Rewrite));
36   EXPECT_EQ(Expected, Context.getRewrittenText(ID));
37 }
38 
TEST(RefactoringCallbacksTest, ReplacesStmtsWithString) {
  // The declaration statement bound to "id" is swapped for an empty statement.
  const std::string Before = "void f() { int i = 1; }";
  const std::string After = "void f() { ; }";
  ReplaceStmtWithText ReplaceDecl("id", ";");
  expectRewritten(Before, After, id("id", declStmt()), ReplaceDecl);
}
45 
TEST(RefactoringCallbacksTest, ReplacesStmtsInCalledMacros) {
  // The declaration is only seen because the macro is expanded on the second
  // line; the expected output shows the edit landing in the macro definition.
  const std::string Before = "#define A void f() { int i = 1; }\nA";
  const std::string After = "#define A void f() { ; }\nA";
  ReplaceStmtWithText ReplaceDecl("id", ";");
  expectRewritten(Before, After, id("id", declStmt()), ReplaceDecl);
}
52 
TEST(RefactoringCallbacksTest, IgnoresStmtsInUncalledMacros) {
  // The macro is never expanded, so declStmt() has nothing to match and the
  // source must come back byte-for-byte unchanged.
  const std::string Unchanged = "#define A void f() { int i = 1; }";
  ReplaceStmtWithText ReplaceDecl("id", ";");
  expectRewritten(Unchanged, Unchanged, id("id", declStmt()), ReplaceDecl);
}
59 
TEST(RefactoringCallbacksTest, ReplacesInteger) {
  // Only the integer literal in the initializer is rewritten; the rest of the
  // declaration is left untouched.
  const std::string Before = "void f() { int i = 1; }";
  const std::string After = "void f() { int i = 2; }";
  ReplaceStmtWithText ReplaceLiteral("id", "2");
  expectRewritten(Before, After,
                  id("id", expr(integerLiteral())), ReplaceLiteral);
}
67 
TEST(RefactoringCallbacksTest, ReplacesStmtWithStmt) {
  // A `?:` whose condition is the literal `false` (bound as "always-false") is
  // replaced by the text of its false branch (bound as "should-be").
  const std::string Before = "void f() { int i = false ? 1 : i * 2; }";
  const std::string After = "void f() { int i = i * 2; }";
  ReplaceStmtWithStmt ReplaceWithBranch("always-false", "should-be");
  auto ConstantFalseConditional = id(
      "always-false",
      conditionalOperator(hasCondition(cxxBoolLiteral(equals(false))),
                          hasFalseExpression(id("should-be", expr()))));
  expectRewritten(Before, After, ConstantFalseConditional, ReplaceWithBranch);
}
78 
TEST(RefactoringCallbacksTest, ReplacesIfStmt) {
  // With the flag set to true, the matched `if` is replaced by its then-branch.
  // The condition reads `a` through an implicit cast, which the matcher looks
  // through with hasSourceExpression.
  const std::string Before =
      "bool a; void f() { if (a) f(); else a = true; }";
  const std::string After = "bool a; void f() { f(); }";
  ReplaceIfStmtWithItsBody KeepThenBranch("id", true);
  auto IfOnVariableA = id(
      "id", ifStmt(hasCondition(implicitCastExpr(hasSourceExpression(
                declRefExpr(to(varDecl(hasName("a")))))))));
  expectRewritten(Before, After, IfOnVariableA, KeepThenBranch);
}
89 
TEST(RefactoringCallbacksTest, RemovesEntireIfOnEmptyElse) {
  // Asking for the else-branch of an `if` that has none removes the whole
  // statement, leaving only the surrounding whitespace behind.
  const std::string Before = "void f() { if (false) int i = 0; }";
  const std::string After = "void f() {  }";
  ReplaceIfStmtWithItsBody KeepElseBranch("id", false);
  expectRewritten(Before, After,
                  id("id", ifStmt(hasCondition(cxxBoolLiteral(equals(false))))),
                  KeepElseBranch);
}
98 
} // end namespace tooling
100 } // end namespace clang
101