1 //===-- SimplifyBooleanExprCheck.cpp - clang-tidy -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "SimplifyBooleanExprCheck.h"
10 #include "clang/AST/RecursiveASTVisitor.h"
11 #include "clang/Lex/Lexer.h"
12 
13 #include <cassert>
14 #include <string>
15 #include <utility>
16 
17 using namespace clang::ast_matchers;
18 
19 namespace clang {
20 namespace tidy {
21 namespace readability {
22 
23 namespace {
24 
getText(const MatchFinder::MatchResult & Result,SourceRange Range)25 StringRef getText(const MatchFinder::MatchResult &Result, SourceRange Range) {
26   return Lexer::getSourceText(CharSourceRange::getTokenRange(Range),
27                               *Result.SourceManager,
28                               Result.Context->getLangOpts());
29 }
30 
31 template <typename T>
getText(const MatchFinder::MatchResult & Result,T & Node)32 StringRef getText(const MatchFinder::MatchResult &Result, T &Node) {
33   return getText(Result, Node.getSourceRange());
34 }
35 
// Binding identifiers shared by the matchers registered in registerMatchers()
// and the dispatch performed in check().
constexpr char ConditionThenStmtId[] = "if-bool-yields-then";
constexpr char ConditionElseStmtId[] = "if-bool-yields-else";
constexpr char TernaryId[] = "ternary-bool-yields-condition";
constexpr char TernaryNegatedId[] = "ternary-bool-yields-not-condition";
constexpr char IfReturnsBoolId[] = "if-return";
constexpr char IfReturnsNotBoolId[] = "if-not-return";
constexpr char ThenLiteralId[] = "then-literal";
constexpr char IfAssignVariableId[] = "if-assign-lvalue";
constexpr char IfAssignLocId[] = "if-assign-loc";
constexpr char IfAssignBoolId[] = "if-assign";
constexpr char IfAssignNotBoolId[] = "if-assign-not";
constexpr char IfAssignVarId[] = "if-assign-var";
constexpr char CompoundReturnId[] = "compound-return";
constexpr char CompoundBoolId[] = "compound-bool";
constexpr char CompoundNotBoolId[] = "compound-bool-not";

// Binds the `if` statement matched by matchBoolCondition().
constexpr char IfStmtId[] = "if";

// Diagnostic messages emitted by this check.
constexpr char SimplifyOperatorDiagnostic[] =
    "redundant boolean literal supplied to boolean operator";
constexpr char SimplifyConditionDiagnostic[] =
    "redundant boolean literal in if statement condition";
constexpr char SimplifyConditionalReturnDiagnostic[] =
    "redundant boolean literal in conditional return statement";
60 
getBoolLiteral(const MatchFinder::MatchResult & Result,StringRef Id)61 const Expr *getBoolLiteral(const MatchFinder::MatchResult &Result,
62                            StringRef Id) {
63   if (const Expr *Literal = Result.Nodes.getNodeAs<CXXBoolLiteralExpr>(Id))
64     return Literal->getBeginLoc().isMacroID() ? nullptr : Literal;
65   if (const auto *Negated = Result.Nodes.getNodeAs<UnaryOperator>(Id)) {
66     if (Negated->getOpcode() == UO_LNot &&
67         isa<CXXBoolLiteralExpr>(Negated->getSubExpr()))
68       return Negated->getBeginLoc().isMacroID() ? nullptr : Negated;
69   }
70   return nullptr;
71 }
72 
literalOrNegatedBool(bool Value)73 internal::BindableMatcher<Stmt> literalOrNegatedBool(bool Value) {
74   return expr(anyOf(cxxBoolLiteral(equals(Value)),
75                     unaryOperator(hasUnaryOperand(ignoringParenImpCasts(
76                                       cxxBoolLiteral(equals(!Value)))),
77                                   hasOperatorName("!"))));
78 }
79 
returnsBool(bool Value,StringRef Id="ignored")80 internal::Matcher<Stmt> returnsBool(bool Value, StringRef Id = "ignored") {
81   auto SimpleReturnsBool = returnStmt(has(literalOrNegatedBool(Value).bind(Id)))
82                                .bind("returns-bool");
83   return anyOf(SimpleReturnsBool,
84                compoundStmt(statementCountIs(1), has(SimpleReturnsBool)));
85 }
86 
needsParensAfterUnaryNegation(const Expr * E)87 bool needsParensAfterUnaryNegation(const Expr *E) {
88   E = E->IgnoreImpCasts();
89   if (isa<BinaryOperator>(E) || isa<ConditionalOperator>(E))
90     return true;
91 
92   if (const auto *Op = dyn_cast<CXXOperatorCallExpr>(E))
93     return Op->getNumArgs() == 2 && Op->getOperator() != OO_Call &&
94            Op->getOperator() != OO_Subscript;
95 
96   return false;
97 }
98 
99 std::pair<BinaryOperatorKind, BinaryOperatorKind> Opposites[] = {
100     {BO_LT, BO_GE}, {BO_GT, BO_LE}, {BO_EQ, BO_NE}};
101 
negatedOperator(const BinaryOperator * BinOp)102 StringRef negatedOperator(const BinaryOperator *BinOp) {
103   const BinaryOperatorKind Opcode = BinOp->getOpcode();
104   for (auto NegatableOp : Opposites) {
105     if (Opcode == NegatableOp.first)
106       return BinOp->getOpcodeStr(NegatableOp.second);
107     if (Opcode == NegatableOp.second)
108       return BinOp->getOpcodeStr(NegatableOp.first);
109   }
110   return StringRef();
111 }
112 
113 std::pair<OverloadedOperatorKind, StringRef> OperatorNames[] = {
114     {OO_EqualEqual, "=="},   {OO_ExclaimEqual, "!="}, {OO_Less, "<"},
115     {OO_GreaterEqual, ">="}, {OO_Greater, ">"},       {OO_LessEqual, "<="}};
116 
getOperatorName(OverloadedOperatorKind OpKind)117 StringRef getOperatorName(OverloadedOperatorKind OpKind) {
118   for (auto Name : OperatorNames) {
119     if (Name.first == OpKind)
120       return Name.second;
121   }
122 
123   return StringRef();
124 }
125 
126 std::pair<OverloadedOperatorKind, OverloadedOperatorKind> OppositeOverloads[] =
127     {{OO_EqualEqual, OO_ExclaimEqual},
128      {OO_Less, OO_GreaterEqual},
129      {OO_Greater, OO_LessEqual}};
130 
negatedOperator(const CXXOperatorCallExpr * OpCall)131 StringRef negatedOperator(const CXXOperatorCallExpr *OpCall) {
132   const OverloadedOperatorKind Opcode = OpCall->getOperator();
133   for (auto NegatableOp : OppositeOverloads) {
134     if (Opcode == NegatableOp.first)
135       return getOperatorName(NegatableOp.second);
136     if (Opcode == NegatableOp.second)
137       return getOperatorName(NegatableOp.first);
138   }
139   return StringRef();
140 }
141 
asBool(StringRef text,bool NeedsStaticCast)142 std::string asBool(StringRef text, bool NeedsStaticCast) {
143   if (NeedsStaticCast)
144     return ("static_cast<bool>(" + text + ")").str();
145 
146   return std::string(text);
147 }
148 
needsNullPtrComparison(const Expr * E)149 bool needsNullPtrComparison(const Expr *E) {
150   if (const auto *ImpCast = dyn_cast<ImplicitCastExpr>(E))
151     return ImpCast->getCastKind() == CK_PointerToBoolean ||
152            ImpCast->getCastKind() == CK_MemberPointerToBoolean;
153 
154   return false;
155 }
156 
needsZeroComparison(const Expr * E)157 bool needsZeroComparison(const Expr *E) {
158   if (const auto *ImpCast = dyn_cast<ImplicitCastExpr>(E))
159     return ImpCast->getCastKind() == CK_IntegralToBoolean;
160 
161   return false;
162 }
163 
needsStaticCast(const Expr * E)164 bool needsStaticCast(const Expr *E) {
165   if (const auto *ImpCast = dyn_cast<ImplicitCastExpr>(E)) {
166     if (ImpCast->getCastKind() == CK_UserDefinedConversion &&
167         ImpCast->getSubExpr()->getType()->isBooleanType()) {
168       if (const auto *MemCall =
169               dyn_cast<CXXMemberCallExpr>(ImpCast->getSubExpr())) {
170         if (const auto *MemDecl =
171                 dyn_cast<CXXConversionDecl>(MemCall->getMethodDecl())) {
172           if (MemDecl->isExplicit())
173             return true;
174         }
175       }
176     }
177   }
178 
179   E = E->IgnoreImpCasts();
180   return !E->getType()->isBooleanType();
181 }
182 
compareExpressionToConstant(const MatchFinder::MatchResult & Result,const Expr * E,bool Negated,const char * Constant)183 std::string compareExpressionToConstant(const MatchFinder::MatchResult &Result,
184                                         const Expr *E, bool Negated,
185                                         const char *Constant) {
186   E = E->IgnoreImpCasts();
187   const std::string ExprText =
188       (isa<BinaryOperator>(E) ? ("(" + getText(Result, *E) + ")")
189                               : getText(Result, *E))
190           .str();
191   return ExprText + " " + (Negated ? "!=" : "==") + " " + Constant;
192 }
193 
compareExpressionToNullPtr(const MatchFinder::MatchResult & Result,const Expr * E,bool Negated)194 std::string compareExpressionToNullPtr(const MatchFinder::MatchResult &Result,
195                                        const Expr *E, bool Negated) {
196   const char *NullPtr =
197       Result.Context->getLangOpts().CPlusPlus11 ? "nullptr" : "NULL";
198   return compareExpressionToConstant(Result, E, Negated, NullPtr);
199 }
200 
compareExpressionToZero(const MatchFinder::MatchResult & Result,const Expr * E,bool Negated)201 std::string compareExpressionToZero(const MatchFinder::MatchResult &Result,
202                                     const Expr *E, bool Negated) {
203   return compareExpressionToConstant(Result, E, Negated, "0");
204 }
205 
replacementExpression(const MatchFinder::MatchResult & Result,bool Negated,const Expr * E)206 std::string replacementExpression(const MatchFinder::MatchResult &Result,
207                                   bool Negated, const Expr *E) {
208   E = E->IgnoreParenBaseCasts();
209   if (const auto *EC = dyn_cast<ExprWithCleanups>(E))
210     E = EC->getSubExpr();
211 
212   const bool NeedsStaticCast = needsStaticCast(E);
213   if (Negated) {
214     if (const auto *UnOp = dyn_cast<UnaryOperator>(E)) {
215       if (UnOp->getOpcode() == UO_LNot) {
216         if (needsNullPtrComparison(UnOp->getSubExpr()))
217           return compareExpressionToNullPtr(Result, UnOp->getSubExpr(), true);
218 
219         if (needsZeroComparison(UnOp->getSubExpr()))
220           return compareExpressionToZero(Result, UnOp->getSubExpr(), true);
221 
222         return replacementExpression(Result, false, UnOp->getSubExpr());
223       }
224     }
225 
226     if (needsNullPtrComparison(E))
227       return compareExpressionToNullPtr(Result, E, false);
228 
229     if (needsZeroComparison(E))
230       return compareExpressionToZero(Result, E, false);
231 
232     StringRef NegatedOperator;
233     const Expr *LHS = nullptr;
234     const Expr *RHS = nullptr;
235     if (const auto *BinOp = dyn_cast<BinaryOperator>(E)) {
236       NegatedOperator = negatedOperator(BinOp);
237       LHS = BinOp->getLHS();
238       RHS = BinOp->getRHS();
239     } else if (const auto *OpExpr = dyn_cast<CXXOperatorCallExpr>(E)) {
240       if (OpExpr->getNumArgs() == 2) {
241         NegatedOperator = negatedOperator(OpExpr);
242         LHS = OpExpr->getArg(0);
243         RHS = OpExpr->getArg(1);
244       }
245     }
246     if (!NegatedOperator.empty() && LHS && RHS)
247       return (asBool((getText(Result, *LHS) + " " + NegatedOperator + " " +
248                       getText(Result, *RHS))
249                          .str(),
250                      NeedsStaticCast));
251 
252     StringRef Text = getText(Result, *E);
253     if (!NeedsStaticCast && needsParensAfterUnaryNegation(E))
254       return ("!(" + Text + ")").str();
255 
256     if (needsNullPtrComparison(E))
257       return compareExpressionToNullPtr(Result, E, false);
258 
259     if (needsZeroComparison(E))
260       return compareExpressionToZero(Result, E, false);
261 
262     return ("!" + asBool(Text, NeedsStaticCast));
263   }
264 
265   if (const auto *UnOp = dyn_cast<UnaryOperator>(E)) {
266     if (UnOp->getOpcode() == UO_LNot) {
267       if (needsNullPtrComparison(UnOp->getSubExpr()))
268         return compareExpressionToNullPtr(Result, UnOp->getSubExpr(), false);
269 
270       if (needsZeroComparison(UnOp->getSubExpr()))
271         return compareExpressionToZero(Result, UnOp->getSubExpr(), false);
272     }
273   }
274 
275   if (needsNullPtrComparison(E))
276     return compareExpressionToNullPtr(Result, E, true);
277 
278   if (needsZeroComparison(E))
279     return compareExpressionToZero(Result, E, true);
280 
281   return asBool(getText(Result, *E), NeedsStaticCast);
282 }
283 
stmtReturnsBool(const ReturnStmt * Ret,bool Negated)284 const Expr *stmtReturnsBool(const ReturnStmt *Ret, bool Negated) {
285   if (const auto *Bool = dyn_cast<CXXBoolLiteralExpr>(Ret->getRetValue())) {
286     if (Bool->getValue() == !Negated)
287       return Bool;
288   }
289   if (const auto *Unary = dyn_cast<UnaryOperator>(Ret->getRetValue())) {
290     if (Unary->getOpcode() == UO_LNot) {
291       if (const auto *Bool =
292               dyn_cast<CXXBoolLiteralExpr>(Unary->getSubExpr())) {
293         if (Bool->getValue() == Negated)
294           return Bool;
295       }
296     }
297   }
298 
299   return nullptr;
300 }
301 
stmtReturnsBool(const IfStmt * IfRet,bool Negated)302 const Expr *stmtReturnsBool(const IfStmt *IfRet, bool Negated) {
303   if (IfRet->getElse() != nullptr)
304     return nullptr;
305 
306   if (const auto *Ret = dyn_cast<ReturnStmt>(IfRet->getThen()))
307     return stmtReturnsBool(Ret, Negated);
308 
309   if (const auto *Compound = dyn_cast<CompoundStmt>(IfRet->getThen())) {
310     if (Compound->size() == 1) {
311       if (const auto *CompoundRet = dyn_cast<ReturnStmt>(Compound->body_back()))
312         return stmtReturnsBool(CompoundRet, Negated);
313     }
314   }
315 
316   return nullptr;
317 }
318 
containsDiscardedTokens(const MatchFinder::MatchResult & Result,CharSourceRange CharRange)319 bool containsDiscardedTokens(const MatchFinder::MatchResult &Result,
320                              CharSourceRange CharRange) {
321   std::string ReplacementText =
322       Lexer::getSourceText(CharRange, *Result.SourceManager,
323                            Result.Context->getLangOpts())
324           .str();
325   Lexer Lex(CharRange.getBegin(), Result.Context->getLangOpts(),
326             ReplacementText.data(), ReplacementText.data(),
327             ReplacementText.data() + ReplacementText.size());
328   Lex.SetCommentRetentionState(true);
329 
330   Token Tok;
331   while (!Lex.LexFromRawLexer(Tok)) {
332     if (Tok.is(tok::TokenKind::comment) || Tok.is(tok::TokenKind::hash))
333       return true;
334   }
335 
336   return false;
337 }
338 
339 } // namespace
340 
341 class SimplifyBooleanExprCheck::Visitor : public RecursiveASTVisitor<Visitor> {
342  public:
Visitor(SimplifyBooleanExprCheck * Check,const MatchFinder::MatchResult & Result)343   Visitor(SimplifyBooleanExprCheck *Check,
344           const MatchFinder::MatchResult &Result)
345       : Check(Check), Result(Result) {}
346 
VisitBinaryOperator(BinaryOperator * Op)347   bool VisitBinaryOperator(BinaryOperator *Op) {
348     Check->reportBinOp(Result, Op);
349     return true;
350   }
351 
352  private:
353   SimplifyBooleanExprCheck *Check;
354   const MatchFinder::MatchResult &Result;
355 };
356 
SimplifyBooleanExprCheck(StringRef Name,ClangTidyContext * Context)357 SimplifyBooleanExprCheck::SimplifyBooleanExprCheck(StringRef Name,
358                                                    ClangTidyContext *Context)
359     : ClangTidyCheck(Name, Context),
360       ChainedConditionalReturn(Options.get("ChainedConditionalReturn", false)),
361       ChainedConditionalAssignment(
362           Options.get("ChainedConditionalAssignment", false)) {}
363 
containsBoolLiteral(const Expr * E)364 bool containsBoolLiteral(const Expr *E) {
365   if (!E)
366     return false;
367   E = E->IgnoreParenImpCasts();
368   if (isa<CXXBoolLiteralExpr>(E))
369     return true;
370   if (const auto *BinOp = dyn_cast<BinaryOperator>(E))
371     return containsBoolLiteral(BinOp->getLHS()) ||
372            containsBoolLiteral(BinOp->getRHS());
373   if (const auto *UnaryOp = dyn_cast<UnaryOperator>(E))
374     return containsBoolLiteral(UnaryOp->getSubExpr());
375   return false;
376 }
377 
reportBinOp(const MatchFinder::MatchResult & Result,const BinaryOperator * Op)378 void SimplifyBooleanExprCheck::reportBinOp(
379     const MatchFinder::MatchResult &Result, const BinaryOperator *Op) {
380   const auto *LHS = Op->getLHS()->IgnoreParenImpCasts();
381   const auto *RHS = Op->getRHS()->IgnoreParenImpCasts();
382 
383   const CXXBoolLiteralExpr *Bool;
384   const Expr *Other = nullptr;
385   if ((Bool = dyn_cast<CXXBoolLiteralExpr>(LHS)))
386     Other = RHS;
387   else if ((Bool = dyn_cast<CXXBoolLiteralExpr>(RHS)))
388     Other = LHS;
389   else
390     return;
391 
392   if (Bool->getBeginLoc().isMacroID())
393     return;
394 
395   // FIXME: why do we need this?
396   if (!isa<CXXBoolLiteralExpr>(Other) && containsBoolLiteral(Other))
397     return;
398 
399   bool BoolValue = Bool->getValue();
400 
401   auto replaceWithExpression = [this, &Result, LHS, RHS, Bool](
402                                    const Expr *ReplaceWith, bool Negated) {
403     std::string Replacement =
404         replacementExpression(Result, Negated, ReplaceWith);
405     SourceRange Range(LHS->getBeginLoc(), RHS->getEndLoc());
406     issueDiag(Result, Bool->getBeginLoc(), SimplifyOperatorDiagnostic, Range,
407               Replacement);
408   };
409 
410   switch (Op->getOpcode()) {
411     case BO_LAnd:
412       if (BoolValue) {
413         // expr && true -> expr
414         replaceWithExpression(Other, /*Negated=*/false);
415       } else {
416         // expr && false -> false
417         replaceWithExpression(Bool, /*Negated=*/false);
418       }
419       break;
420     case BO_LOr:
421       if (BoolValue) {
422         // expr || true -> true
423         replaceWithExpression(Bool, /*Negated=*/false);
424       } else {
425         // expr || false -> expr
426         replaceWithExpression(Other, /*Negated=*/false);
427       }
428       break;
429     case BO_EQ:
430       // expr == true -> expr, expr == false -> !expr
431       replaceWithExpression(Other, /*Negated=*/!BoolValue);
432       break;
433     case BO_NE:
434       // expr != true -> !expr, expr != false -> expr
435       replaceWithExpression(Other, /*Negated=*/BoolValue);
436       break;
437     default:
438       break;
439   }
440 }
441 
matchBoolCondition(MatchFinder * Finder,bool Value,StringRef BooleanId)442 void SimplifyBooleanExprCheck::matchBoolCondition(MatchFinder *Finder,
443                                                   bool Value,
444                                                   StringRef BooleanId) {
445   Finder->addMatcher(
446       ifStmt(unless(isInTemplateInstantiation()),
447              hasCondition(literalOrNegatedBool(Value).bind(BooleanId)))
448           .bind(IfStmtId),
449       this);
450 }
451 
matchTernaryResult(MatchFinder * Finder,bool Value,StringRef TernaryId)452 void SimplifyBooleanExprCheck::matchTernaryResult(MatchFinder *Finder,
453                                                   bool Value,
454                                                   StringRef TernaryId) {
455   Finder->addMatcher(
456       conditionalOperator(unless(isInTemplateInstantiation()),
457                           hasTrueExpression(literalOrNegatedBool(Value)),
458                           hasFalseExpression(literalOrNegatedBool(!Value)))
459           .bind(TernaryId),
460       this);
461 }
462 
matchIfReturnsBool(MatchFinder * Finder,bool Value,StringRef Id)463 void SimplifyBooleanExprCheck::matchIfReturnsBool(MatchFinder *Finder,
464                                                   bool Value, StringRef Id) {
465   if (ChainedConditionalReturn)
466     Finder->addMatcher(ifStmt(unless(isInTemplateInstantiation()),
467                               hasThen(returnsBool(Value, ThenLiteralId)),
468                               hasElse(returnsBool(!Value)))
469                            .bind(Id),
470                        this);
471   else
472     Finder->addMatcher(ifStmt(unless(isInTemplateInstantiation()),
473                               unless(hasParent(ifStmt())),
474                               hasThen(returnsBool(Value, ThenLiteralId)),
475                               hasElse(returnsBool(!Value)))
476                            .bind(Id),
477                        this);
478 }
479 
matchIfAssignsBool(MatchFinder * Finder,bool Value,StringRef Id)480 void SimplifyBooleanExprCheck::matchIfAssignsBool(MatchFinder *Finder,
481                                                   bool Value, StringRef Id) {
482   auto VarAssign = declRefExpr(hasDeclaration(decl().bind(IfAssignVarId)));
483   auto VarRef = declRefExpr(hasDeclaration(equalsBoundNode(IfAssignVarId)));
484   auto MemAssign = memberExpr(hasDeclaration(decl().bind(IfAssignVarId)));
485   auto MemRef = memberExpr(hasDeclaration(equalsBoundNode(IfAssignVarId)));
486   auto SimpleThen =
487       binaryOperator(hasOperatorName("="), hasLHS(anyOf(VarAssign, MemAssign)),
488                      hasLHS(expr().bind(IfAssignVariableId)),
489                      hasRHS(literalOrNegatedBool(Value).bind(IfAssignLocId)));
490   auto Then = anyOf(SimpleThen, compoundStmt(statementCountIs(1),
491                                              hasAnySubstatement(SimpleThen)));
492   auto SimpleElse =
493       binaryOperator(hasOperatorName("="), hasLHS(anyOf(VarRef, MemRef)),
494                      hasRHS(literalOrNegatedBool(!Value)));
495   auto Else = anyOf(SimpleElse, compoundStmt(statementCountIs(1),
496                                              hasAnySubstatement(SimpleElse)));
497   if (ChainedConditionalAssignment)
498     Finder->addMatcher(ifStmt(unless(isInTemplateInstantiation()),
499                               hasThen(Then), hasElse(Else))
500                            .bind(Id),
501                        this);
502   else
503     Finder->addMatcher(ifStmt(unless(isInTemplateInstantiation()),
504                               unless(hasParent(ifStmt())), hasThen(Then),
505                               hasElse(Else))
506                            .bind(Id),
507                        this);
508 }
509 
matchCompoundIfReturnsBool(MatchFinder * Finder,bool Value,StringRef Id)510 void SimplifyBooleanExprCheck::matchCompoundIfReturnsBool(MatchFinder *Finder,
511                                                           bool Value,
512                                                           StringRef Id) {
513   Finder->addMatcher(
514       compoundStmt(
515           unless(isInTemplateInstantiation()),
516           hasAnySubstatement(
517               ifStmt(hasThen(returnsBool(Value)), unless(hasElse(stmt())))),
518           hasAnySubstatement(returnStmt(has(ignoringParenImpCasts(
519                                             literalOrNegatedBool(!Value))))
520                                  .bind(CompoundReturnId)))
521           .bind(Id),
522       this);
523 }
524 
storeOptions(ClangTidyOptions::OptionMap & Opts)525 void SimplifyBooleanExprCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
526   Options.store(Opts, "ChainedConditionalReturn", ChainedConditionalReturn);
527   Options.store(Opts, "ChainedConditionalAssignment",
528                 ChainedConditionalAssignment);
529 }
530 
registerMatchers(MatchFinder * Finder)531 void SimplifyBooleanExprCheck::registerMatchers(MatchFinder *Finder) {
532   Finder->addMatcher(translationUnitDecl().bind("top"), this);
533 
534   matchBoolCondition(Finder, true, ConditionThenStmtId);
535   matchBoolCondition(Finder, false, ConditionElseStmtId);
536 
537   matchTernaryResult(Finder, true, TernaryId);
538   matchTernaryResult(Finder, false, TernaryNegatedId);
539 
540   matchIfReturnsBool(Finder, true, IfReturnsBoolId);
541   matchIfReturnsBool(Finder, false, IfReturnsNotBoolId);
542 
543   matchIfAssignsBool(Finder, true, IfAssignBoolId);
544   matchIfAssignsBool(Finder, false, IfAssignNotBoolId);
545 
546   matchCompoundIfReturnsBool(Finder, true, CompoundBoolId);
547   matchCompoundIfReturnsBool(Finder, false, CompoundNotBoolId);
548 }
549 
check(const MatchFinder::MatchResult & Result)550 void SimplifyBooleanExprCheck::check(const MatchFinder::MatchResult &Result) {
551   if (Result.Nodes.getNodeAs<TranslationUnitDecl>("top"))
552     Visitor(this, Result).TraverseAST(*Result.Context);
553   else if (const Expr *TrueConditionRemoved =
554                getBoolLiteral(Result, ConditionThenStmtId))
555     replaceWithThenStatement(Result, TrueConditionRemoved);
556   else if (const Expr *FalseConditionRemoved =
557                getBoolLiteral(Result, ConditionElseStmtId))
558     replaceWithElseStatement(Result, FalseConditionRemoved);
559   else if (const auto *Ternary =
560                Result.Nodes.getNodeAs<ConditionalOperator>(TernaryId))
561     replaceWithCondition(Result, Ternary);
562   else if (const auto *TernaryNegated =
563                Result.Nodes.getNodeAs<ConditionalOperator>(TernaryNegatedId))
564     replaceWithCondition(Result, TernaryNegated, true);
565   else if (const auto *If = Result.Nodes.getNodeAs<IfStmt>(IfReturnsBoolId))
566     replaceWithReturnCondition(Result, If);
567   else if (const auto *IfNot =
568                Result.Nodes.getNodeAs<IfStmt>(IfReturnsNotBoolId))
569     replaceWithReturnCondition(Result, IfNot, true);
570   else if (const auto *IfAssign =
571                Result.Nodes.getNodeAs<IfStmt>(IfAssignBoolId))
572     replaceWithAssignment(Result, IfAssign);
573   else if (const auto *IfAssignNot =
574                Result.Nodes.getNodeAs<IfStmt>(IfAssignNotBoolId))
575     replaceWithAssignment(Result, IfAssignNot, true);
576   else if (const auto *Compound =
577                Result.Nodes.getNodeAs<CompoundStmt>(CompoundBoolId))
578     replaceCompoundReturnWithCondition(Result, Compound);
579   else if (const auto *Compound =
580                Result.Nodes.getNodeAs<CompoundStmt>(CompoundNotBoolId))
581     replaceCompoundReturnWithCondition(Result, Compound, true);
582 }
583 
issueDiag(const ast_matchers::MatchFinder::MatchResult & Result,SourceLocation Loc,StringRef Description,SourceRange ReplacementRange,StringRef Replacement)584 void SimplifyBooleanExprCheck::issueDiag(
585     const ast_matchers::MatchFinder::MatchResult &Result, SourceLocation Loc,
586     StringRef Description, SourceRange ReplacementRange,
587     StringRef Replacement) {
588   CharSourceRange CharRange =
589       Lexer::makeFileCharRange(CharSourceRange::getTokenRange(ReplacementRange),
590                                *Result.SourceManager, getLangOpts());
591 
592   DiagnosticBuilder Diag = diag(Loc, Description);
593   if (!containsDiscardedTokens(Result, CharRange))
594     Diag << FixItHint::CreateReplacement(CharRange, Replacement);
595 }
596 
replaceWithThenStatement(const MatchFinder::MatchResult & Result,const Expr * TrueConditionRemoved)597 void SimplifyBooleanExprCheck::replaceWithThenStatement(
598     const MatchFinder::MatchResult &Result, const Expr *TrueConditionRemoved) {
599   const auto *IfStatement = Result.Nodes.getNodeAs<IfStmt>(IfStmtId);
600   issueDiag(Result, TrueConditionRemoved->getBeginLoc(),
601             SimplifyConditionDiagnostic, IfStatement->getSourceRange(),
602             getText(Result, *IfStatement->getThen()));
603 }
604 
replaceWithElseStatement(const MatchFinder::MatchResult & Result,const Expr * FalseConditionRemoved)605 void SimplifyBooleanExprCheck::replaceWithElseStatement(
606     const MatchFinder::MatchResult &Result, const Expr *FalseConditionRemoved) {
607   const auto *IfStatement = Result.Nodes.getNodeAs<IfStmt>(IfStmtId);
608   const Stmt *ElseStatement = IfStatement->getElse();
609   issueDiag(Result, FalseConditionRemoved->getBeginLoc(),
610             SimplifyConditionDiagnostic, IfStatement->getSourceRange(),
611             ElseStatement ? getText(Result, *ElseStatement) : "");
612 }
613 
replaceWithCondition(const MatchFinder::MatchResult & Result,const ConditionalOperator * Ternary,bool Negated)614 void SimplifyBooleanExprCheck::replaceWithCondition(
615     const MatchFinder::MatchResult &Result, const ConditionalOperator *Ternary,
616     bool Negated) {
617   std::string Replacement =
618       replacementExpression(Result, Negated, Ternary->getCond());
619   issueDiag(Result, Ternary->getTrueExpr()->getBeginLoc(),
620             "redundant boolean literal in ternary expression result",
621             Ternary->getSourceRange(), Replacement);
622 }
623 
replaceWithReturnCondition(const MatchFinder::MatchResult & Result,const IfStmt * If,bool Negated)624 void SimplifyBooleanExprCheck::replaceWithReturnCondition(
625     const MatchFinder::MatchResult &Result, const IfStmt *If, bool Negated) {
626   StringRef Terminator = isa<CompoundStmt>(If->getElse()) ? ";" : "";
627   std::string Condition = replacementExpression(Result, Negated, If->getCond());
628   std::string Replacement = ("return " + Condition + Terminator).str();
629   SourceLocation Start =
630       Result.Nodes.getNodeAs<CXXBoolLiteralExpr>(ThenLiteralId)->getBeginLoc();
631   issueDiag(Result, Start, SimplifyConditionalReturnDiagnostic,
632             If->getSourceRange(), Replacement);
633 }
634 
replaceCompoundReturnWithCondition(const MatchFinder::MatchResult & Result,const CompoundStmt * Compound,bool Negated)635 void SimplifyBooleanExprCheck::replaceCompoundReturnWithCondition(
636     const MatchFinder::MatchResult &Result, const CompoundStmt *Compound,
637     bool Negated) {
638   const auto *Ret = Result.Nodes.getNodeAs<ReturnStmt>(CompoundReturnId);
639 
640   // The body shouldn't be empty because the matcher ensures that it must
641   // contain at least two statements:
642   // 1) A `return` statement returning a boolean literal `false` or `true`
643   // 2) An `if` statement with no `else` clause that consists of a single
644   //    `return` statement returning the opposite boolean literal `true` or
645   //    `false`.
646   assert(Compound->size() >= 2);
647   const IfStmt *BeforeIf = nullptr;
648   CompoundStmt::const_body_iterator Current = Compound->body_begin();
649   CompoundStmt::const_body_iterator After = Compound->body_begin();
650   for (++After; After != Compound->body_end() && *Current != Ret;
651        ++Current, ++After) {
652     if (const auto *If = dyn_cast<IfStmt>(*Current)) {
653       if (const Expr *Lit = stmtReturnsBool(If, Negated)) {
654         if (*After == Ret) {
655           if (!ChainedConditionalReturn && BeforeIf)
656             continue;
657 
658           const Expr *Condition = If->getCond();
659           std::string Replacement =
660               "return " + replacementExpression(Result, Negated, Condition);
661           issueDiag(
662               Result, Lit->getBeginLoc(), SimplifyConditionalReturnDiagnostic,
663               SourceRange(If->getBeginLoc(), Ret->getEndLoc()), Replacement);
664           return;
665         }
666 
667         BeforeIf = If;
668       }
669     } else {
670       BeforeIf = nullptr;
671     }
672   }
673 }
674 
replaceWithAssignment(const MatchFinder::MatchResult & Result,const IfStmt * IfAssign,bool Negated)675 void SimplifyBooleanExprCheck::replaceWithAssignment(
676     const MatchFinder::MatchResult &Result, const IfStmt *IfAssign,
677     bool Negated) {
678   SourceRange Range = IfAssign->getSourceRange();
679   StringRef VariableName =
680       getText(Result, *Result.Nodes.getNodeAs<Expr>(IfAssignVariableId));
681   StringRef Terminator = isa<CompoundStmt>(IfAssign->getElse()) ? ";" : "";
682   std::string Condition =
683       replacementExpression(Result, Negated, IfAssign->getCond());
684   std::string Replacement =
685       (VariableName + " = " + Condition + Terminator).str();
686   SourceLocation Location =
687       Result.Nodes.getNodeAs<CXXBoolLiteralExpr>(IfAssignLocId)->getBeginLoc();
688   issueDiag(Result, Location,
689             "redundant boolean literal in conditional assignment", Range,
690             Replacement);
691 }
692 
693 } // namespace readability
694 } // namespace tidy
695 } // namespace clang
696