1 //===--- PopulateSwitch.cpp --------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Tweak that populates an empty switch statement of an enumeration type with
10 // all of the enumerators of that type.
11 //
12 // Before:
13 //   enum Color { RED, GREEN, BLUE };
14 //
15 //   void f(Color color) {
16 //     switch (color) {}
17 //   }
18 //
19 // After:
20 //   enum Color { RED, GREEN, BLUE };
21 //
22 //   void f(Color color) {
23 //     switch (color) {
24 //     case RED:
25 //     case GREEN:
26 //     case BLUE:
27 //       break;
28 //     }
29 //   }
30 //
31 //===----------------------------------------------------------------------===//
32 
33 #include "AST.h"
34 #include "Selection.h"
35 #include "refactor/Tweak.h"
36 #include "support/Logger.h"
37 #include "clang/AST/Decl.h"
38 #include "clang/AST/Stmt.h"
39 #include "clang/AST/Type.h"
40 #include "clang/Basic/SourceLocation.h"
41 #include "clang/Basic/SourceManager.h"
42 #include "clang/Tooling/Core/Replacement.h"
43 #include "llvm/ADT/MapVector.h"
44 #include "llvm/ADT/STLExtras.h"
45 #include <cassert>
46 #include <string>
47 
48 namespace clang {
49 namespace clangd {
50 namespace {
51 class PopulateSwitch : public Tweak {
52   const char *id() const override;
53   bool prepare(const Selection &Sel) override;
54   Expected<Effect> apply(const Selection &Sel) override;
title() const55   std::string title() const override { return "Populate switch"; }
kind() const56   llvm::StringLiteral kind() const override {
57     return CodeAction::QUICKFIX_KIND;
58   }
59 
60 private:
61   class ExpectedCase {
62   public:
ExpectedCase(const EnumConstantDecl * Decl)63     ExpectedCase(const EnumConstantDecl *Decl) : Data(Decl, false) {}
isCovered() const64     bool isCovered() const { return Data.getInt(); }
setCovered(bool Val=true)65     void setCovered(bool Val = true) { Data.setInt(Val); }
getEnumConstant() const66     const EnumConstantDecl *getEnumConstant() const {
67       return Data.getPointer();
68     }
69 
70   private:
71     llvm::PointerIntPair<const EnumConstantDecl *, 1, bool> Data;
72   };
73 
74   const DeclContext *DeclCtx = nullptr;
75   const SwitchStmt *Switch = nullptr;
76   const CompoundStmt *Body = nullptr;
77   const EnumType *EnumT = nullptr;
78   const EnumDecl *EnumD = nullptr;
79   // Maps the Enum values to the EnumConstantDecl and a bool signifying if its
80   // covered in the switch.
81   llvm::MapVector<llvm::APSInt, ExpectedCase> ExpectedCases;
82 };
83 
REGISTER_TWEAK(PopulateSwitch)84 REGISTER_TWEAK(PopulateSwitch)
85 
86 bool PopulateSwitch::prepare(const Selection &Sel) {
87   const SelectionTree::Node *CA = Sel.ASTSelection.commonAncestor();
88   if (!CA)
89     return false;
90 
91   const Stmt *CAStmt = CA->ASTNode.get<Stmt>();
92   if (!CAStmt)
93     return false;
94 
95   // Go up a level if we see a compound statement.
96   // switch (value) {}
97   //                ^^
98   if (isa<CompoundStmt>(CAStmt)) {
99     CA = CA->Parent;
100     if (!CA)
101       return false;
102 
103     CAStmt = CA->ASTNode.get<Stmt>();
104     if (!CAStmt)
105       return false;
106   }
107 
108   DeclCtx = &CA->getDeclContext();
109   Switch = dyn_cast<SwitchStmt>(CAStmt);
110   if (!Switch)
111     return false;
112 
113   Body = dyn_cast<CompoundStmt>(Switch->getBody());
114   if (!Body)
115     return false;
116 
117   const Expr *Cond = Switch->getCond();
118   if (!Cond)
119     return false;
120 
121   // Ignore implicit casts, since enums implicitly cast to integer types.
122   Cond = Cond->IgnoreParenImpCasts();
123 
124   EnumT = Cond->getType()->getAsAdjusted<EnumType>();
125   if (!EnumT)
126     return false;
127 
128   EnumD = EnumT->getDecl();
129   if (!EnumD || EnumD->isDependentType())
130     return false;
131 
132   // We trigger if there are any values in the enum that aren't covered by the
133   // switch.
134 
135   ASTContext &Ctx = Sel.AST->getASTContext();
136 
137   unsigned EnumIntWidth = Ctx.getIntWidth(QualType(EnumT, 0));
138   bool EnumIsSigned = EnumT->isSignedIntegerOrEnumerationType();
139 
140   auto Normalize = [&](llvm::APSInt Val) {
141     Val = Val.extOrTrunc(EnumIntWidth);
142     Val.setIsSigned(EnumIsSigned);
143     return Val;
144   };
145 
146   for (auto *EnumConstant : EnumD->enumerators()) {
147     ExpectedCases.insert(
148         std::make_pair(Normalize(EnumConstant->getInitVal()), EnumConstant));
149   }
150 
151   for (const SwitchCase *CaseList = Switch->getSwitchCaseList(); CaseList;
152        CaseList = CaseList->getNextSwitchCase()) {
153     // Default likely intends to cover cases we'd insert.
154     if (isa<DefaultStmt>(CaseList))
155       return false;
156 
157     const CaseStmt *CS = cast<CaseStmt>(CaseList);
158 
159     // GNU range cases are rare, we don't support them.
160     if (CS->caseStmtIsGNURange())
161       return false;
162 
163     // Case expression is not a constant expression or is value-dependent,
164     // so we may not be able to work out which cases are covered.
165     const ConstantExpr *CE = dyn_cast<ConstantExpr>(CS->getLHS());
166     if (!CE || CE->isValueDependent())
167       return false;
168 
169     // Unsure if this case could ever come up, but prevents an unreachable
170     // executing in getResultAsAPSInt.
171     if (CE->getResultStorageKind() == ConstantExpr::RSK_None)
172       return false;
173     auto Iter = ExpectedCases.find(Normalize(CE->getResultAsAPSInt()));
174     if (Iter != ExpectedCases.end())
175       Iter->second.setCovered();
176   }
177 
178   return !llvm::all_of(ExpectedCases,
179                        [](auto &Pair) { return Pair.second.isCovered(); });
180 }
181 
apply(const Selection & Sel)182 Expected<Tweak::Effect> PopulateSwitch::apply(const Selection &Sel) {
183   ASTContext &Ctx = Sel.AST->getASTContext();
184 
185   SourceLocation Loc = Body->getRBracLoc();
186   ASTContext &DeclASTCtx = DeclCtx->getParentASTContext();
187 
188   llvm::SmallString<256> Text;
189   for (auto &EnumConstant : ExpectedCases) {
190     // Skip any enum constants already covered
191     if (EnumConstant.second.isCovered())
192       continue;
193 
194     Text.append({"case ", getQualification(DeclASTCtx, DeclCtx, Loc, EnumD)});
195     if (EnumD->isScoped())
196       Text.append({EnumD->getName(), "::"});
197     Text.append({EnumConstant.second.getEnumConstant()->getName(), ":"});
198   }
199 
200   assert(!Text.empty() && "No enumerators to insert!");
201   Text += "break;";
202 
203   const SourceManager &SM = Ctx.getSourceManager();
204   return Effect::mainFileEdit(
205       SM, tooling::Replacements(tooling::Replacement(SM, Loc, 0, Text)));
206 }
207 } // namespace
208 } // namespace clangd
209 } // namespace clang
210