1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/ir/SkSLSwitchStatement.h"
9 
10 #include <forward_list>
11 
12 #include "include/private/SkTHash.h"
13 #include "src/sksl/SkSLAnalysis.h"
14 #include "src/sksl/SkSLConstantFolder.h"
15 #include "src/sksl/SkSLContext.h"
16 #include "src/sksl/SkSLProgramSettings.h"
17 #include "src/sksl/ir/SkSLBlock.h"
18 #include "src/sksl/ir/SkSLNop.h"
19 #include "src/sksl/ir/SkSLSymbolTable.h"
20 #include "src/sksl/ir/SkSLType.h"
21 
22 namespace SkSL {
23 
clone() const24 std::unique_ptr<Statement> SwitchStatement::clone() const {
25     StatementArray cases;
26     cases.reserve_back(this->cases().size());
27     for (const std::unique_ptr<Statement>& stmt : this->cases()) {
28         cases.push_back(stmt->clone());
29     }
30     return std::make_unique<SwitchStatement>(fOffset,
31                                              this->isStatic(),
32                                              this->value()->clone(),
33                                              std::move(cases),
34                                              SymbolTable::WrapIfBuiltin(this->symbols()));
35 }
36 
description() const37 String SwitchStatement::description() const {
38     String result;
39     if (this->isStatic()) {
40         result += "@";
41     }
42     result += String::printf("switch (%s) {\n", this->value()->description().c_str());
43     for (const auto& c : this->cases()) {
44         result += c->description();
45     }
46     result += "}";
47     return result;
48 }
49 
find_duplicate_case_values(const StatementArray & cases)50 static std::forward_list<const SwitchCase*> find_duplicate_case_values(
51         const StatementArray& cases) {
52     std::forward_list<const SwitchCase*> duplicateCases;
53     SkTHashSet<SKSL_INT> intValues;
54     bool foundDefault = false;
55 
56     for (const std::unique_ptr<Statement>& stmt : cases) {
57         const SwitchCase* sc = &stmt->as<SwitchCase>();
58         const std::unique_ptr<Expression>& valueExpr = sc->value();
59 
60         // A null case-value indicates the `default` switch-case.
61         if (!valueExpr) {
62             if (foundDefault) {
63                 duplicateCases.push_front(sc);
64                 continue;
65             }
66             foundDefault = true;
67             continue;
68         }
69 
70         // GetConstantInt already succeeded when the SwitchCase was first assembled, so it should
71         // succeed this time too.
72         SKSL_INT intValue = 0;
73         SkAssertResult(ConstantFolder::GetConstantInt(*valueExpr, &intValue));
74         if (intValues.contains(intValue)) {
75             duplicateCases.push_front(sc);
76             continue;
77         }
78         intValues.add(intValue);
79     }
80 
81     return duplicateCases;
82 }
83 
move_all_but_break(std::unique_ptr<Statement> & stmt,StatementArray * target)84 static void move_all_but_break(std::unique_ptr<Statement>& stmt, StatementArray* target) {
85     switch (stmt->kind()) {
86         case Statement::Kind::kBlock: {
87             // Recurse into the block.
88             Block& block = stmt->as<Block>();
89 
90             StatementArray blockStmts;
91             blockStmts.reserve_back(block.children().size());
92             for (std::unique_ptr<Statement>& stmt : block.children()) {
93                 move_all_but_break(stmt, &blockStmts);
94             }
95 
96             target->push_back(Block::Make(block.fOffset, std::move(blockStmts),
97                                           block.symbolTable(), block.isScope()));
98             break;
99         }
100 
101         case Statement::Kind::kBreak:
102             // Do not append a break to the target.
103             break;
104 
105         default:
106             // Append normal statements to the target.
107             target->push_back(std::move(stmt));
108             break;
109     }
110 }
111 
BlockForCase(StatementArray * cases,SwitchCase * caseToCapture,std::shared_ptr<SymbolTable> symbolTable)112 std::unique_ptr<Statement> SwitchStatement::BlockForCase(StatementArray* cases,
113                                                          SwitchCase* caseToCapture,
114                                                          std::shared_ptr<SymbolTable> symbolTable) {
115     // We have to be careful to not move any of the pointers until after we're sure we're going to
116     // succeed, so before we make any changes at all, we check the switch-cases to decide on a plan
117     // of action. First, find the switch-case we are interested in.
118     auto iter = cases->begin();
119     for (; iter != cases->end(); ++iter) {
120         const SwitchCase& sc = (*iter)->as<SwitchCase>();
121         if (&sc == caseToCapture) {
122             break;
123         }
124     }
125 
126     // Next, walk forward through the rest of the switch. If we find a conditional break, we're
127     // stuck and can't simplify at all. If we find an unconditional break, we have a range of
128     // statements that we can use for simplification.
129     auto startIter = iter;
130     Statement* stripBreakStmt = nullptr;
131     for (; iter != cases->end(); ++iter) {
132         std::unique_ptr<Statement>& stmt = (*iter)->as<SwitchCase>().statement();
133         if (Analysis::SwitchCaseContainsConditionalExit(*stmt)) {
134             // We can't reduce switch-cases to a block when they have conditional exits.
135             return nullptr;
136         }
137         if (Analysis::SwitchCaseContainsUnconditionalExit(*stmt)) {
138             // We found an unconditional exit. We can use this block, but we'll need to strip
139             // out the break statement if there is one.
140             stripBreakStmt = stmt.get();
141             break;
142         }
143     }
144 
145     // We fell off the bottom of the switch or encountered a break. We know the range of statements
146     // that we need to move over, and we know it's safe to do so.
147     StatementArray caseStmts;
148     caseStmts.reserve_back(std::distance(startIter, iter) + 1);
149 
150     // We can move over most of the statements as-is.
151     while (startIter != iter) {
152         caseStmts.push_back(std::move((*startIter)->as<SwitchCase>().statement()));
153         ++startIter;
154     }
155 
156     // If we found an unconditional break at the end, we need to move what we can while avoiding
157     // that break.
158     if (stripBreakStmt != nullptr) {
159         SkASSERT((*startIter)->as<SwitchCase>().statement().get() == stripBreakStmt);
160         move_all_but_break((*startIter)->as<SwitchCase>().statement(), &caseStmts);
161     }
162 
163     // Return our newly-synthesized block.
164     return Block::Make(caseToCapture->fOffset, std::move(caseStmts), std::move(symbolTable));
165 }
166 
Convert(const Context & context,int offset,bool isStatic,std::unique_ptr<Expression> value,ExpressionArray caseValues,StatementArray caseStatements,std::shared_ptr<SymbolTable> symbolTable)167 std::unique_ptr<Statement> SwitchStatement::Convert(const Context& context,
168                                                     int offset,
169                                                     bool isStatic,
170                                                     std::unique_ptr<Expression> value,
171                                                     ExpressionArray caseValues,
172                                                     StatementArray caseStatements,
173                                                     std::shared_ptr<SymbolTable> symbolTable) {
174     SkASSERT(caseValues.size() == caseStatements.size());
175     if (context.fConfig->strictES2Mode()) {
176         context.fErrors.error(offset, "switch statements are not supported");
177         return nullptr;
178     }
179 
180     if (!value->type().isEnum()) {
181         value = context.fTypes.fInt->coerceExpression(std::move(value), context);
182         if (!value) {
183             return nullptr;
184         }
185     }
186 
187     StatementArray cases;
188     for (int i = 0; i < caseValues.count(); ++i) {
189         int caseOffset;
190         std::unique_ptr<Expression> caseValue;
191         if (caseValues[i]) {
192             caseOffset = caseValues[i]->fOffset;
193 
194             // Case values must be the same type as the switch value--`int` or a particular enum.
195             caseValue = value->type().coerceExpression(std::move(caseValues[i]), context);
196             if (!caseValue) {
197                 return nullptr;
198             }
199             // Case values must be a literal integer or a `const int` variable reference.
200             SKSL_INT intValue;
201             if (!ConstantFolder::GetConstantInt(*caseValue, &intValue)) {
202                 context.fErrors.error(caseValue->fOffset, "case value must be a constant integer");
203                 return nullptr;
204             }
205         } else {
206             // The null case-expression corresponds to `default:`.
207             caseOffset = offset;
208         }
209         cases.push_back(std::make_unique<SwitchCase>(caseOffset, std::move(caseValue),
210                                                      std::move(caseStatements[i])));
211     }
212 
213     // Detect duplicate `case` labels and report an error.
214     // (Using forward_list here to optimize for the common case of no results.)
215     std::forward_list<const SwitchCase*> duplicateCases = find_duplicate_case_values(cases);
216     if (!duplicateCases.empty()) {
217         duplicateCases.reverse();
218         for (const SwitchCase* sc : duplicateCases) {
219             if (sc->value() != nullptr) {
220                 context.fErrors.error(sc->fOffset,
221                                       "duplicate case value '" + sc->value()->description() + "'");
222             } else {
223                 context.fErrors.error(sc->fOffset, "duplicate default case");
224             }
225         }
226         return nullptr;
227     }
228 
229     return SwitchStatement::Make(context, offset, isStatic, std::move(value), std::move(cases),
230                                  std::move(symbolTable));
231 }
232 
Make(const Context & context,int offset,bool isStatic,std::unique_ptr<Expression> value,StatementArray cases,std::shared_ptr<SymbolTable> symbolTable)233 std::unique_ptr<Statement> SwitchStatement::Make(const Context& context,
234                                                  int offset,
235                                                  bool isStatic,
236                                                  std::unique_ptr<Expression> value,
237                                                  StatementArray cases,
238                                                  std::shared_ptr<SymbolTable> symbolTable) {
239     // Confirm that every statement in `cases` is a SwitchCase.
240     SkASSERT(std::all_of(cases.begin(), cases.end(), [&](const std::unique_ptr<Statement>& stmt) {
241         return stmt->is<SwitchCase>();
242     }));
243 
244     // Confirm that every switch-case has been coerced to the proper type.
245     SkASSERT(std::all_of(cases.begin(), cases.end(), [&](const std::unique_ptr<Statement>& stmt) {
246         return !stmt->as<SwitchCase>().value() ||  // `default` case has a null value
247                value->type() == stmt->as<SwitchCase>().value()->type();
248     }));
249 
250     // Confirm that every switch-case value is unique.
251     SkASSERT(find_duplicate_case_values(cases).empty());
252 
253     // Flatten @switch statements.
254     if (isStatic || context.fConfig->fSettings.fOptimize) {
255         SKSL_INT switchValue;
256         if (ConstantFolder::GetConstantInt(*value, &switchValue)) {
257             SwitchCase* defaultCase = nullptr;
258             SwitchCase* matchingCase = nullptr;
259             for (const std::unique_ptr<Statement>& stmt : cases) {
260                 SwitchCase& sc = stmt->as<SwitchCase>();
261                 if (!sc.value()) {
262                     defaultCase = &sc;
263                     continue;
264                 }
265 
266                 SKSL_INT caseValue;
267                 SkAssertResult(ConstantFolder::GetConstantInt(*sc.value(), &caseValue));
268                 if (caseValue == switchValue) {
269                     matchingCase = &sc;
270                     break;
271                 }
272             }
273 
274             if (!matchingCase) {
275                 // No case value matches the switch value.
276                 if (!defaultCase) {
277                     // No default switch-case exists; the switch had no effect.
278                     // We can eliminate the entire switch!
279                     return Nop::Make();
280                 }
281                 // We had a default case; that's what we matched with.
282                 matchingCase = defaultCase;
283             }
284 
285             // Convert the switch-case that we matched with into a block.
286             std::unique_ptr<Statement> newBlock = BlockForCase(&cases, matchingCase, symbolTable);
287             if (newBlock) {
288                 return newBlock;
289             }
290 
291             // Report an error if this was a static switch and BlockForCase failed us.
292             if (isStatic && !context.fConfig->fSettings.fPermitInvalidStaticTests) {
293                 context.fErrors.error(value->fOffset,
294                                       "static switch contains non-static conditional exit");
295                 return nullptr;
296             }
297         }
298     }
299 
300     // The switch couldn't be optimized away; emit it normally.
301     return std::make_unique<SwitchStatement>(offset, isStatic, std::move(value), std::move(cases),
302                                              std::move(symbolTable));
303 }
304 
305 }  // namespace SkSL
306