1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <assert.h>
6 #include <stdlib.h>
7 #include <algorithm>
8 #include <memory>
9 #include <string>
10 
11 #include "clang/AST/ASTContext.h"
12 #include "clang/AST/ParentMap.h"
13 #include "clang/ASTMatchers/ASTMatchFinder.h"
14 #include "clang/ASTMatchers/ASTMatchers.h"
15 #include "clang/ASTMatchers/ASTMatchersMacros.h"
16 #include "clang/Analysis/CFG.h"
17 #include "clang/Basic/SourceManager.h"
18 #include "clang/Frontend/FrontendActions.h"
19 #include "clang/Lex/Lexer.h"
20 #include "clang/Tooling/CommonOptionsParser.h"
21 #include "clang/Tooling/Refactoring.h"
22 #include "clang/Tooling/Tooling.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/TargetSelect.h"
25 
26 using Replacements = std::vector<clang::tooling::Replacement>;
27 using clang::ASTContext;
28 using clang::CFG;
29 using clang::CFGBlock;
30 using clang::CFGLifetimeEnds;
31 using clang::CFGStmt;
32 using clang::CallExpr;
33 using clang::Decl;
34 using clang::DeclRefExpr;
35 using clang::FunctionDecl;
36 using clang::LambdaExpr;
37 using clang::Stmt;
38 using clang::UnaryOperator;
39 using clang::ast_type_traits::DynTypedNode;
40 using clang::tooling::CommonOptionsParser;
41 using namespace clang::ast_matchers;
42 
43 namespace {
44 
// Common base class for every rewriter in this tool. main() holds the selected
// rewriter through a std::unique_ptr<Rewriter>, so the destructor is virtual
// to let the derived object be destroyed through the base pointer.
class Rewriter {
 public:
  virtual ~Rewriter() = default;
};
49 
// Removes an unneeded base::Passed() wrapped around an argument of
// base::BindOnce().
51 // Example:
52 //   // Before
53 //   base::BindOnce(&Foo, base::Passed(&bar));
54 //   base::BindOnce(&Foo, base::Passed(std::move(baz)));
55 //   base::BindOnce(&Foo, base::Passed(qux));
56 //
57 //   // After
58 //   base::BindOnce(&Foo, std::move(bar));
59 //   base::BindOnce(&Foo, std::move(baz));
60 //   base::BindOnce(&Foo, std::move(*qux));
61 class PassedToMoveRewriter : public MatchFinder::MatchCallback,
62                              public Rewriter {
63  public:
PassedToMoveRewriter(Replacements * replacements)64   explicit PassedToMoveRewriter(Replacements* replacements)
65       : replacements_(replacements) {}
66 
GetMatcher()67   StatementMatcher GetMatcher() {
68     auto is_passed = namedDecl(hasName("::base::Passed"));
69     auto is_bind_once_call = callee(namedDecl(hasName("::base::BindOnce")));
70 
71     // Matches base::Passed() call on a base::BindOnce() argument.
72     return callExpr(is_bind_once_call,
73                     hasAnyArgument(ignoringImplicit(
74                         callExpr(callee(is_passed)).bind("target"))));
75   }
76 
run(const MatchFinder::MatchResult & result)77   void run(const MatchFinder::MatchResult& result) override {
78     auto* target = result.Nodes.getNodeAs<CallExpr>("target");
79     auto* callee = target->getCallee()->IgnoreImpCasts();
80 
81     auto* callee_decl = clang::dyn_cast<DeclRefExpr>(callee)->getDecl();
82     auto* passed_decl = clang::dyn_cast<FunctionDecl>(callee_decl);
83     auto* param_type = passed_decl->getParamDecl(0)->getType().getTypePtr();
84 
85     if (param_type->isRValueReferenceType()) {
86       // base::Passed(xxx) -> xxx.
87       // The parameter type is already an rvalue reference.
88       // Example:
89       //   std::unique_ptr<int> foo();
90       //   std::unique_ptr<int> bar;
91       //   base::Passed(foo());
92       //   base::Passed(std::move(bar));
93       // In these cases, we can just remove base::Passed.
94       auto left = clang::CharSourceRange::getTokenRange(
95           result.SourceManager->getSpellingLoc(target->getLocStart()),
96           result.SourceManager->getSpellingLoc(target->getArg(0)->getExprLoc())
97               .getLocWithOffset(-1));
98       auto r_paren = clang::CharSourceRange::getTokenRange(
99           result.SourceManager->getSpellingLoc(target->getRParenLoc()),
100           result.SourceManager->getSpellingLoc(target->getRParenLoc()));
101       replacements_->emplace_back(*result.SourceManager, left, " ");
102       replacements_->emplace_back(*result.SourceManager, r_paren, " ");
103       return;
104     }
105 
106     if (!param_type->isPointerType())
107       return;
108 
109     auto* passed_arg = target->getArg(0)->IgnoreImpCasts();
110     if (auto* unary = clang::dyn_cast<clang::UnaryOperator>(passed_arg)) {
111       if (unary->getOpcode() == clang::UO_AddrOf) {
112         // base::Passed(&xxx) -> std::move(xxx).
113         auto left = clang::CharSourceRange::getTokenRange(
114             result.SourceManager->getSpellingLoc(target->getLocStart()),
115             result.SourceManager->getSpellingLoc(
116                 target->getArg(0)->getExprLoc()));
117         replacements_->emplace_back(*result.SourceManager, left, "std::move(");
118         return;
119       }
120     }
121 
122     // base::Passed(xxx) -> std::move(*xxx)
123     auto left = clang::CharSourceRange::getTokenRange(
124         result.SourceManager->getSpellingLoc(target->getLocStart()),
125         result.SourceManager->getSpellingLoc(target->getArg(0)->getExprLoc())
126             .getLocWithOffset(-1));
127     replacements_->emplace_back(*result.SourceManager, left, "std::move(*");
128   }
129 
130  private:
131   Replacements* replacements_;
132 };
133 
// Replaces base::Bind() with base::BindOnce() where the resulting
// base::Callback is implicitly converted into a base::OnceCallback.
136 // Example:
137 //   // Before
138 //   base::PostTask(FROM_HERE, base::Bind(&Foo));
139 //   base::OnceCallback<void()> cb = base::Bind(&Foo);
140 //
141 //   // After
142 //   base::PostTask(FROM_HERE, base::BindOnce(&Foo));
143 //   base::OnceCallback<void()> cb = base::BindOnce(&Foo);
144 class BindOnceRewriter : public MatchFinder::MatchCallback, public Rewriter {
145  public:
BindOnceRewriter(Replacements * replacements)146   explicit BindOnceRewriter(Replacements* replacements)
147       : replacements_(replacements) {}
148 
GetMatcher()149   StatementMatcher GetMatcher() {
150     auto is_once_callback = hasType(hasCanonicalType(hasDeclaration(
151         classTemplateSpecializationDecl(hasName("::base::OnceCallback")))));
152     auto is_repeating_callback =
153         hasType(hasCanonicalType(hasDeclaration(classTemplateSpecializationDecl(
154             hasName("::base::RepeatingCallback")))));
155 
156     auto bind_call =
157         callExpr(callee(namedDecl(hasName("::base::Bind")))).bind("target");
158     auto parameter_construction =
159         cxxConstructExpr(is_repeating_callback, argumentCountIs(1),
160                          hasArgument(0, ignoringImplicit(bind_call)));
161     auto constructor_conversion = cxxConstructExpr(
162         is_once_callback, argumentCountIs(1),
163         hasArgument(0, ignoringImplicit(parameter_construction)));
164     return implicitCastExpr(is_once_callback,
165                             hasSourceExpression(constructor_conversion));
166   }
167 
run(const MatchFinder::MatchResult & result)168   void run(const MatchFinder::MatchResult& result) override {
169     auto* target = result.Nodes.getNodeAs<clang::CallExpr>("target");
170     auto* callee = target->getCallee();
171     auto range = clang::CharSourceRange::getTokenRange(
172         result.SourceManager->getSpellingLoc(callee->getLocEnd()),
173         result.SourceManager->getSpellingLoc(callee->getLocEnd()));
174     replacements_->emplace_back(*result.SourceManager, range, "BindOnce");
175   }
176 
177  private:
178   Replacements* replacements_;
179 };
180 
// Converts pass-by-const-ref base::Callback parameters to pass-by-value.
182 // Example:
183 //   // Before
184 //   using BarCallback = base::Callback<void(void*)>;
185 //   void Foo(const base::Callback<void(int)>& cb);
186 //   void Bar(const BarCallback& cb);
187 //
188 //   // After
189 //   using BarCallback = base::Callback<void(void*)>;
190 //   void Foo(base::Callback<void(int)> cb);
191 //   void Bar(BarCallback cb);
192 class PassByValueRewriter : public MatchFinder::MatchCallback, public Rewriter {
193  public:
PassByValueRewriter(Replacements * replacements)194   explicit PassByValueRewriter(Replacements* replacements)
195       : replacements_(replacements) {}
196 
GetMatcher()197   DeclarationMatcher GetMatcher() {
198     auto is_repeating_callback =
199         namedDecl(hasName("::base::RepeatingCallback"));
200     return parmVarDecl(
201                hasType(hasCanonicalType(references(is_repeating_callback))))
202         .bind("target");
203   }
204 
run(const MatchFinder::MatchResult & result)205   void run(const MatchFinder::MatchResult& result) override {
206     auto* target = result.Nodes.getNodeAs<clang::ParmVarDecl>("target");
207     auto qual_type = target->getType();
208     auto* ref_type =
209         clang::dyn_cast<clang::LValueReferenceType>(qual_type.getTypePtr());
210     if (!ref_type || !ref_type->getPointeeType().isLocalConstQualified())
211       return;
212 
213     // Remove the leading `const` and the following `&`.
214     auto type_loc = target->getTypeSourceInfo()->getTypeLoc();
215     auto const_keyword = clang::CharSourceRange::getTokenRange(
216         result.SourceManager->getSpellingLoc(target->getLocStart()),
217         result.SourceManager->getSpellingLoc(target->getLocStart()));
218     auto lvalue_ref = clang::CharSourceRange::getTokenRange(
219         result.SourceManager->getSpellingLoc(type_loc.getLocEnd()),
220         result.SourceManager->getSpellingLoc(type_loc.getLocEnd()));
221     replacements_->emplace_back(*result.SourceManager, const_keyword, " ");
222     replacements_->emplace_back(*result.SourceManager, lvalue_ref, " ");
223   }
224 
225  private:
226   Replacements* replacements_;
227 };
228 
// Adds std::move() around uses of local base::RepeatingCallback<> variables
// when the variable does not appear to be used again afterwards.
230 // Example:
231 //   // Before
232 //   void Foo(base::Callback<void(int)> cb1) {
233 //     base::Closure cb2 = base::Bind(cb1, 42);
234 //     PostTask(FROM_HERE, cb2);
235 //   }
236 //
237 //   // After
238 //   void Foo(base::Callback<void(int)> cb1) {
239 //     base::Closure cb2 = base::Bind(std::move(cb1), 42);
240 //     PostTask(FROM_HERE, std::move(cb2));
241 //   }
242 class AddStdMoveRewriter : public MatchFinder::MatchCallback, public Rewriter {
243  public:
AddStdMoveRewriter(Replacements * replacements)244   explicit AddStdMoveRewriter(Replacements* replacements)
245       : replacements_(replacements) {}
246 
GetMatcher()247   StatementMatcher GetMatcher() {
248     return declRefExpr(
249                hasType(hasCanonicalType(hasDeclaration(
250                    namedDecl(hasName("::base::RepeatingCallback"))))),
251                anyOf(hasAncestor(cxxConstructorDecl().bind("enclosing_ctor")),
252                      hasAncestor(functionDecl().bind("enclosing_func")),
253                      hasAncestor(lambdaExpr().bind("enclosing_lambda"))))
254         .bind("target");
255   }
256 
257   // Build Control Flow Graph (CFG) for |stmt| and populate class members with
258   // the content of the graph. Returns true if the analysis finished
259   // successfully.
ExtractCFGContentToMembers(Stmt * stmt,ASTContext * context)260   bool ExtractCFGContentToMembers(Stmt* stmt, ASTContext* context) {
261     // Try to make a cache entry. The failure implies it's already in the cache.
262     auto inserted = cfg_cache_.emplace(stmt, nullptr);
263     if (!inserted.second)
264       return !!inserted.first->second;
265 
266     std::unique_ptr<CFG>& cfg = inserted.first->second;
267     CFG::BuildOptions opts;
268     opts.AddInitializers = true;
269     opts.AddLifetime = true;
270     opts.AddStaticInitBranches = true;
271     cfg = CFG::buildCFG(nullptr, stmt, context, opts);
272 
273     // CFG construction may fail. Report it to the caller.
274     if (!cfg)
275       return false;
276     if (!parent_map_)
277       parent_map_ = llvm::make_unique<clang::ParentMap>(stmt);
278     else
279       parent_map_->addStmt(stmt);
280 
281     // Populate |top_stmts_|, that contains Stmts that is evaluated in its own
282     // CFGElement.
283     for (auto* block : *cfg) {
284       for (auto& elem : *block) {
285         if (auto stmt = elem.getAs<CFGStmt>())
286           top_stmts_.insert(stmt->getStmt());
287       }
288     }
289 
290     // Populate |enclosing_block_|, that maps a Stmt to a CFGBlock that contains
291     // the Stmt.
292     std::function<void(const CFGBlock*, const Stmt*)> recursive_set_enclosing =
293         [&](const CFGBlock* block, const Stmt* stmt) {
294           enclosing_block_[stmt] = block;
295           for (auto* c : stmt->children()) {
296             if (!c)
297               continue;
298             if (top_stmts_.find(c) != top_stmts_.end())
299               continue;
300             recursive_set_enclosing(block, c);
301           }
302         };
303     for (auto* block : *cfg) {
304       for (auto& elem : *block) {
305         if (auto stmt = elem.getAs<CFGStmt>())
306           recursive_set_enclosing(block, stmt->getStmt());
307       }
308     }
309 
310     return true;
311   }
312 
EnclosingCxxStatement(const Stmt * stmt)313   const Stmt* EnclosingCxxStatement(const Stmt* stmt) {
314     while (true) {
315       const Stmt* parent = parent_map_->getParentIgnoreParenCasts(stmt);
316       assert(parent);
317       switch (parent->getStmtClass()) {
318         case Stmt::CompoundStmtClass:
319         case Stmt::ForStmtClass:
320         case Stmt::CXXForRangeStmtClass:
321         case Stmt::WhileStmtClass:
322         case Stmt::DoStmtClass:
323         case Stmt::IfStmtClass:
324 
325           // Other candidates:
326           //   Stmt::CXXTryStmtClass
327           //   Stmt::CXXCatchStmtClass
328           //   Stmt::CapturedStmtClass
329           //   Stmt::SwitchStmtClass
330           //   Stmt::SwitchCaseClass
331           return stmt;
332         default:
333           stmt = parent;
334           break;
335       }
336     }
337   }
338 
WasPointerTaken(const Stmt * stmt,const Decl * decl)339   bool WasPointerTaken(const Stmt* stmt, const Decl* decl) {
340     std::function<bool(const Stmt*)> visit_stmt = [&](const Stmt* stmt) {
341       if (auto* op = clang::dyn_cast<UnaryOperator>(stmt)) {
342         if (op->getOpcode() == clang::UO_AddrOf) {
343           auto* ref = clang::dyn_cast<DeclRefExpr>(op->getSubExpr());
344           // |ref| may be null if the sub-expr has a dependent type.
345           if (ref && ref->getDecl() == decl)
346             return true;
347         }
348       }
349 
350       for (auto* c : stmt->children()) {
351         if (!c)
352           continue;
353         if (visit_stmt(c))
354           return true;
355       }
356       return false;
357     };
358     return visit_stmt(stmt);
359   }
360 
HasCapturingLambda(const Stmt * stmt,const Decl * decl)361   bool HasCapturingLambda(const Stmt* stmt, const Decl* decl) {
362     std::function<bool(const Stmt*)> visit_stmt = [&](const Stmt* stmt) {
363       if (auto* l = clang::dyn_cast<LambdaExpr>(stmt)) {
364         for (auto c : l->captures()) {
365           if (c.getCapturedVar() == decl)
366             return true;
367         }
368       }
369 
370       for (auto* c : stmt->children()) {
371         if (!c)
372           continue;
373 
374         if (visit_stmt(c))
375           return true;
376       }
377 
378       return false;
379     };
380     return visit_stmt(stmt);
381   }
382 
383   // Returns true if there are multiple occurrences to |decl| in one of C++
384   // statements in |stmt|.
HasUnorderedOccurrences(const Decl * decl,const Stmt * stmt)385   bool HasUnorderedOccurrences(const Decl* decl, const Stmt* stmt) {
386     int count = 0;
387     std::function<void(const Stmt*)> visit_stmt = [&](const Stmt* s) {
388       if (auto* ref = clang::dyn_cast<DeclRefExpr>(s)) {
389         if (ref->getDecl() == decl)
390           ++count;
391       }
392       for (auto* c : s->children()) {
393         if (!c)
394           continue;
395         visit_stmt(c);
396       }
397     };
398 
399     visit_stmt(EnclosingCxxStatement(stmt));
400     return count > 1;
401   }
402 
run(const MatchFinder::MatchResult & result)403   void run(const MatchFinder::MatchResult& result) override {
404     auto* target = result.Nodes.getNodeAs<clang::DeclRefExpr>("target");
405     auto* decl = clang::dyn_cast<clang::VarDecl>(target->getDecl());
406 
407     // Other than local variables and parameters are out-of-scope.
408     if (!decl || !decl->isLocalVarDeclOrParm())
409       return;
410 
411     auto qual_type = decl->getType();
412     // Qualified variables are out-of-scope. They are likely not movable.
413     if (qual_type.getCanonicalType().hasQualifiers())
414       return;
415 
416     auto* type = qual_type.getTypePtr();
417     // References and pointers are out-of-scope.
418     if (type->isReferenceType() || type->isPointerType())
419       return;
420 
421     Stmt* body = nullptr;
422     if (auto* ctor = result.Nodes.getNodeAs<LambdaExpr>("enclosing_ctor"))
423       return;  // Skip constructor case for now. TBD.
424     else if (auto* func =
425                  result.Nodes.getNodeAs<FunctionDecl>("enclosing_func"))
426       body = func->getBody();
427     else if (auto* lambda =
428                  result.Nodes.getNodeAs<LambdaExpr>("enclosing_lambda"))
429       body = lambda->getBody();
430     else
431       return;
432 
433     // Disable the replacement if there is a lambda that captures |decl|.
434     if (HasCapturingLambda(body, decl))
435       return;
436 
437     // Disable the replacement if the pointer to |decl| is taken in the scope.
438     if (WasPointerTaken(body, decl))
439       return;
440 
441     if (!ExtractCFGContentToMembers(body, result.Context))
442       return;
443 
444     auto* parent = parent_map_->getParentIgnoreParenCasts(target);
445     if (auto* p = clang::dyn_cast<CallExpr>(parent)) {
446       auto* callee = p->getCalleeDecl();
447       // |callee| may be null if the CallExpr has an unresolved look up.
448       if (!callee)
449         return;
450       auto* callee_decl = clang::dyn_cast<clang::NamedDecl>(callee);
451       auto name = callee_decl->getQualifiedNameAsString();
452 
453       // Disable the replacement if it's already in std::move() or
454       // std::forward().
455       if (name == "std::__1::move" || name == "std::__1::forward")
456         return;
457     } else if (parent->getStmtClass() == Stmt::ReturnStmtClass) {
458       // Disable the replacement if it's in a return statement.
459       return;
460     }
461 
462     // If the same C++ statement contains multiple reference to the variable,
463     // don't insert std::move() to be conservative.
464     if (HasUnorderedOccurrences(decl, target))
465       return;
466 
467     bool saw_reuse = false;
468     ForEachFollowingStmts(target, [&](const Stmt* stmt) {
469       if (auto* ref = clang::dyn_cast<DeclRefExpr>(stmt)) {
470         if (ref->getDecl() == decl) {
471           saw_reuse = true;
472           return false;
473         }
474       }
475 
476       // TODO: Detect Reset() and operator=() to stop the traversal.
477       return true;
478     });
479     if (saw_reuse)
480       return;
481 
482     replacements_->emplace_back(
483         *result.SourceManager,
484         result.SourceManager->getSpellingLoc(target->getLocStart()), 0,
485         "std::move(");
486     replacements_->emplace_back(
487         *result.SourceManager,
488         clang::Lexer::getLocForEndOfToken(target->getLocEnd(), 0,
489                                           *result.SourceManager,
490                                           result.Context->getLangOpts()),
491         0, ")");
492   }
493 
494   // Invokes |handler| for each Stmt that follows |target| until it reaches the
495   // end of the lifetime of the variable that |target| references.
496   // If |handler| returns false, stops following the current control flow.
ForEachFollowingStmts(const DeclRefExpr * target,std::function<bool (const Stmt *)> handler)497   void ForEachFollowingStmts(const DeclRefExpr* target,
498                              std::function<bool(const Stmt*)> handler) {
499     auto* decl = target->getDecl();
500     auto* block = enclosing_block_[target];
501 
502     std::set<const clang::CFGBlock*> visited;
503     std::vector<const clang::CFGBlock*> stack = {block};
504 
505     bool saw_target = false;
506     std::function<bool(const Stmt*)> visit_stmt = [&](const Stmt* s) {
507       for (auto* t : s->children()) {
508         if (!t)
509           continue;
510 
511         // |t| is evaluated elsewhere if a sub-Stmt is in |top_stmt_|.
512         if (top_stmts_.find(t) != top_stmts_.end())
513           continue;
514 
515         if (!visit_stmt(t))
516           return false;
517       }
518 
519       if (!saw_target) {
520         if (s == target)
521           saw_target = true;
522         return true;
523       }
524 
525       return handler(s);
526     };
527 
528     bool visited_initial_block_twice = false;
529     while (!stack.empty()) {
530       auto* b = stack.back();
531       stack.pop_back();
532       if (!visited.insert(b).second) {
533         if (b != block || visited_initial_block_twice)
534           continue;
535         visited_initial_block_twice = true;
536       }
537 
538       bool cont = true;
539       for (auto e : *b) {
540         if (auto s = e.getAs<CFGStmt>()) {
541           if (!visit_stmt(s->getStmt())) {
542             cont = false;
543             break;
544           }
545         } else if (auto l = e.getAs<CFGLifetimeEnds>()) {
546           if (l->getVarDecl() == decl) {
547             cont = false;
548             break;
549           }
550         }
551       }
552 
553       if (cont) {
554         for (auto s : b->succs()) {
555           if (!s)
556             continue;  // Unreachable block.
557           stack.push_back(s);
558         }
559       }
560     }
561   }
562 
563  private:
564   // Function body to CFG.
565   std::map<const Stmt*, std::unique_ptr<CFG>> cfg_cache_;
566 
567   // Statement to the enclosing CFGBlock.
568   std::map<const Stmt*, const CFGBlock*> enclosing_block_;
569 
570   // Stmt to its parent Stmt.
571   std::unique_ptr<clang::ParentMap> parent_map_;
572 
573   // A set of Stmt that a CFGElement has it directly.
574   std::set<const Stmt*> top_stmts_;
575 
576   Replacements* replacements_;
577 };
578 
// Removes base::AdaptCallbackForRepeating() where the resulting
// base::RepeatingCallback is implicitly converted into a base::OnceCallback.
581 // Example:
582 //   // Before
583 //   base::PostTask(
584 //       FROM_HERE,
585 //       base::AdaptCallbackForRepeating(base::BindOnce(&Foo)));
586 //   base::OnceCallback<void()> cb = base::AdaptCallbackForRepeating(
//       base::BindOnce(&Foo));
588 //
589 //   // After
590 //   base::PostTask(FROM_HERE, base::BindOnce(&Foo));
591 //   base::OnceCallback<void()> cb = base::BindOnce(&Foo);
592 class AdaptCallbackForRepeatingRewriter : public MatchFinder::MatchCallback,
593                                           public Rewriter {
594  public:
AdaptCallbackForRepeatingRewriter(Replacements * replacements)595   explicit AdaptCallbackForRepeatingRewriter(Replacements* replacements)
596       : replacements_(replacements) {}
597 
GetMatcher()598   StatementMatcher GetMatcher() {
599     auto is_once_callback = hasType(hasCanonicalType(hasDeclaration(
600         classTemplateSpecializationDecl(hasName("::base::OnceCallback")))));
601     auto is_repeating_callback =
602         hasType(hasCanonicalType(hasDeclaration(classTemplateSpecializationDecl(
603             hasName("::base::RepeatingCallback")))));
604 
605     auto adapt_callback_call =
606         callExpr(
607             callee(namedDecl(hasName("::base::AdaptCallbackForRepeating"))))
608             .bind("target");
609     auto parameter_construction =
610         cxxConstructExpr(is_repeating_callback, argumentCountIs(1),
611                          hasArgument(0, ignoringImplicit(adapt_callback_call)));
612     auto constructor_conversion = cxxConstructExpr(
613         is_once_callback, argumentCountIs(1),
614         hasArgument(0, ignoringImplicit(parameter_construction)));
615     return implicitCastExpr(is_once_callback,
616                             hasSourceExpression(constructor_conversion));
617   }
618 
run(const MatchFinder::MatchResult & result)619   void run(const MatchFinder::MatchResult& result) override {
620     auto* target = result.Nodes.getNodeAs<clang::CallExpr>("target");
621 
622     auto left = clang::CharSourceRange::getTokenRange(
623         result.SourceManager->getSpellingLoc(target->getLocStart()),
624         result.SourceManager->getSpellingLoc(target->getArg(0)->getExprLoc())
625             .getLocWithOffset(-1));
626 
627     // We use " " as replacement to work around https://crbug.com/861886.
628     replacements_->emplace_back(*result.SourceManager, left, " ");
629     auto r_paren = clang::CharSourceRange::getTokenRange(
630         result.SourceManager->getSpellingLoc(target->getRParenLoc()),
631         result.SourceManager->getSpellingLoc(target->getRParenLoc()));
632     replacements_->emplace_back(*result.SourceManager, r_paren, " ");
633   }
634 
635  private:
636   Replacements* replacements_;
637 };
638 
639 llvm::cl::extrahelp common_help(CommonOptionsParser::HelpMessage);
640 llvm::cl::OptionCategory rewriter_category("Rewriter Options");
641 
642 llvm::cl::opt<std::string> rewriter_option(
643     "rewriter",
644     llvm::cl::desc(R"(One of the name of rewriter to apply.
645 Available rewriters are:
646     remove_unneeded_passed
647     bind_to_bind_once
648     pass_by_value
649     add_std_move
650     remove_unneeded_adapt_callback
651 The default is remove_unneeded_passed.
652 )"),
653     llvm::cl::init("remove_unneeded_passed"),
654     llvm::cl::cat(rewriter_category));
655 
656 }  // namespace.
657 
main(int argc,const char * argv[])658 int main(int argc, const char* argv[]) {
659   llvm::InitializeNativeTarget();
660   llvm::InitializeNativeTargetAsmParser();
661   CommonOptionsParser options(argc, argv, rewriter_category);
662   clang::tooling::ClangTool tool(options.getCompilations(),
663                                  options.getSourcePathList());
664 
665   MatchFinder match_finder;
666   std::vector<clang::tooling::Replacement> replacements;
667 
668   std::unique_ptr<Rewriter> rewriter;
669   if (rewriter_option == "remove_unneeded_passed") {
670     auto passed_to_move =
671         llvm::make_unique<PassedToMoveRewriter>(&replacements);
672     match_finder.addMatcher(passed_to_move->GetMatcher(), passed_to_move.get());
673     rewriter = std::move(passed_to_move);
674   } else if (rewriter_option == "bind_to_bind_once") {
675     auto bind_once = llvm::make_unique<BindOnceRewriter>(&replacements);
676     match_finder.addMatcher(bind_once->GetMatcher(), bind_once.get());
677     rewriter = std::move(bind_once);
678   } else if (rewriter_option == "pass_by_value") {
679     auto pass_by_value = llvm::make_unique<PassByValueRewriter>(&replacements);
680     match_finder.addMatcher(pass_by_value->GetMatcher(), pass_by_value.get());
681     rewriter = std::move(pass_by_value);
682   } else if (rewriter_option == "add_std_move") {
683     auto add_std_move = llvm::make_unique<AddStdMoveRewriter>(&replacements);
684     match_finder.addMatcher(add_std_move->GetMatcher(), add_std_move.get());
685     rewriter = std::move(add_std_move);
686   } else if (rewriter_option == "remove_unneeded_adapt_callback") {
687     auto remove_unneeded_adapt_callback =
688         llvm::make_unique<AdaptCallbackForRepeatingRewriter>(&replacements);
689     match_finder.addMatcher(remove_unneeded_adapt_callback->GetMatcher(),
690                             remove_unneeded_adapt_callback.get());
691     rewriter = std::move(remove_unneeded_adapt_callback);
692   } else {
693     abort();
694   }
695 
696   std::unique_ptr<clang::tooling::FrontendActionFactory> factory =
697       clang::tooling::newFrontendActionFactory(&match_finder);
698   int result = tool.run(factory.get());
699   if (result != 0)
700     return result;
701 
702   // Serialization format is documented in tools/clang/scripts/run_tool.py
703   llvm::outs() << "==== BEGIN EDITS ====\n";
704   for (const auto& r : replacements) {
705     std::string replacement_text = r.getReplacementText().str();
706     std::replace(replacement_text.begin(), replacement_text.end(), '\n', '\0');
707     llvm::outs() << "r:::" << r.getFilePath() << ":::" << r.getOffset()
708                  << ":::" << r.getLength() << ":::" << replacement_text << "\n";
709   }
710   llvm::outs() << "==== END EDITS ====\n";
711 
712   return 0;
713 }
714