1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // UnfoldShortCircuitToIf is an AST traverser to convert short-circuiting operators to if-else
7 // statements.
8 // The results are assigned to s# temporaries, which are used by the main translator instead of
9 // the original expression.
10 //
11 
12 #include "compiler/translator/tree_ops/d3d/UnfoldShortCircuitToIf.h"
13 
14 #include "compiler/translator/StaticType.h"
15 #include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18 
19 namespace sh
20 {
21 
22 namespace
23 {
24 
25 // Traverser that unfolds one short-circuiting operation at a time.
26 class UnfoldShortCircuitTraverser : public TIntermTraverser
27 {
28   public:
29     UnfoldShortCircuitTraverser(TSymbolTable *symbolTable);
30 
31     bool visitBinary(Visit visit, TIntermBinary *node) override;
32     bool visitTernary(Visit visit, TIntermTernary *node) override;
33 
34     void nextIteration();
foundShortCircuit() const35     bool foundShortCircuit() const { return mFoundShortCircuit; }
36 
37   protected:
38     // Marked to true once an operation that needs to be unfolded has been found.
39     // After that, no more unfolding is performed on that traversal.
40     bool mFoundShortCircuit;
41 
42     IntermNodePatternMatcher mPatternToUnfoldMatcher;
43 };
44 
UnfoldShortCircuitTraverser(TSymbolTable * symbolTable)45 UnfoldShortCircuitTraverser::UnfoldShortCircuitTraverser(TSymbolTable *symbolTable)
46     : TIntermTraverser(true, false, true, symbolTable),
47       mFoundShortCircuit(false),
48       mPatternToUnfoldMatcher(IntermNodePatternMatcher::kUnfoldedShortCircuitExpression)
49 {}
50 
visitBinary(Visit visit,TIntermBinary * node)51 bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
52 {
53     if (mFoundShortCircuit)
54         return false;
55 
56     if (visit != PreVisit)
57         return true;
58 
59     if (!mPatternToUnfoldMatcher.match(node, getParentNode()))
60         return true;
61 
62     // If our right node doesn't have side effects, we know we don't need to unfold this
63     // expression: there will be no short-circuiting side effects to avoid
64     // (note: unfolding doesn't depend on the left node -- it will always be evaluated)
65     ASSERT(node->getRight()->hasSideEffects());
66 
67     mFoundShortCircuit = true;
68 
69     switch (node->getOp())
70     {
71         case EOpLogicalOr:
72         {
73             // "x || y" is equivalent to "x ? true : y", which unfolds to "bool s; if(x) s = true;
74             // else s = y;",
75             // and then further simplifies down to "bool s = x; if(!s) s = y;".
76 
77             TIntermSequence insertions;
78             const TType *boolType = StaticType::Get<EbtBool, EbpUndefined, EvqTemporary, 1, 1>();
79             TVariable *resultVariable = CreateTempVariable(mSymbolTable, boolType);
80 
81             ASSERT(node->getLeft()->getType() == *boolType);
82             insertions.push_back(CreateTempInitDeclarationNode(resultVariable, node->getLeft()));
83 
84             TIntermBlock *assignRightBlock = new TIntermBlock();
85             ASSERT(node->getRight()->getType() == *boolType);
86             assignRightBlock->getSequence()->push_back(
87                 CreateTempAssignmentNode(resultVariable, node->getRight()));
88 
89             TIntermUnary *notTempSymbol =
90                 new TIntermUnary(EOpLogicalNot, CreateTempSymbolNode(resultVariable), nullptr);
91             TIntermIfElse *ifNode = new TIntermIfElse(notTempSymbol, assignRightBlock, nullptr);
92             insertions.push_back(ifNode);
93 
94             insertStatementsInParentBlock(insertions);
95 
96             queueReplacement(CreateTempSymbolNode(resultVariable), OriginalNode::IS_DROPPED);
97             return false;
98         }
99         case EOpLogicalAnd:
100         {
101             // "x && y" is equivalent to "x ? y : false", which unfolds to "bool s; if(x) s = y;
102             // else s = false;",
103             // and then further simplifies down to "bool s = x; if(s) s = y;".
104             TIntermSequence insertions;
105             const TType *boolType = StaticType::Get<EbtBool, EbpUndefined, EvqTemporary, 1, 1>();
106             TVariable *resultVariable = CreateTempVariable(mSymbolTable, boolType);
107 
108             ASSERT(node->getLeft()->getType() == *boolType);
109             insertions.push_back(CreateTempInitDeclarationNode(resultVariable, node->getLeft()));
110 
111             TIntermBlock *assignRightBlock = new TIntermBlock();
112             ASSERT(node->getRight()->getType() == *boolType);
113             assignRightBlock->getSequence()->push_back(
114                 CreateTempAssignmentNode(resultVariable, node->getRight()));
115 
116             TIntermIfElse *ifNode =
117                 new TIntermIfElse(CreateTempSymbolNode(resultVariable), assignRightBlock, nullptr);
118             insertions.push_back(ifNode);
119 
120             insertStatementsInParentBlock(insertions);
121 
122             queueReplacement(CreateTempSymbolNode(resultVariable), OriginalNode::IS_DROPPED);
123             return false;
124         }
125         default:
126             UNREACHABLE();
127             return true;
128     }
129 }
130 
visitTernary(Visit visit,TIntermTernary * node)131 bool UnfoldShortCircuitTraverser::visitTernary(Visit visit, TIntermTernary *node)
132 {
133     if (mFoundShortCircuit)
134         return false;
135 
136     if (visit != PreVisit)
137         return true;
138 
139     if (!mPatternToUnfoldMatcher.match(node))
140         return true;
141 
142     mFoundShortCircuit = true;
143 
144     // Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;"
145     TIntermSequence insertions;
146     TIntermDeclaration *tempDeclaration = nullptr;
147     TVariable *resultVariable = DeclareTempVariable(mSymbolTable, new TType(node->getType()),
148                                                     EvqTemporary, &tempDeclaration);
149     insertions.push_back(tempDeclaration);
150 
151     TIntermBlock *trueBlock = new TIntermBlock();
152     TIntermBinary *trueAssignment =
153         CreateTempAssignmentNode(resultVariable, node->getTrueExpression());
154     trueBlock->getSequence()->push_back(trueAssignment);
155 
156     TIntermBlock *falseBlock = new TIntermBlock();
157     TIntermBinary *falseAssignment =
158         CreateTempAssignmentNode(resultVariable, node->getFalseExpression());
159     falseBlock->getSequence()->push_back(falseAssignment);
160 
161     TIntermIfElse *ifNode =
162         new TIntermIfElse(node->getCondition()->getAsTyped(), trueBlock, falseBlock);
163     insertions.push_back(ifNode);
164 
165     insertStatementsInParentBlock(insertions);
166 
167     TIntermSymbol *ternaryResult = CreateTempSymbolNode(resultVariable);
168     queueReplacement(ternaryResult, OriginalNode::IS_DROPPED);
169 
170     return false;
171 }
172 
nextIteration()173 void UnfoldShortCircuitTraverser::nextIteration()
174 {
175     mFoundShortCircuit = false;
176 }
177 
178 }  // namespace
179 
UnfoldShortCircuitToIf(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)180 bool UnfoldShortCircuitToIf(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable)
181 {
182     UnfoldShortCircuitTraverser traverser(symbolTable);
183     // Unfold one operator at a time, and reset the traverser between iterations.
184     do
185     {
186         traverser.nextIteration();
187         root->traverse(&traverser);
188         if (traverser.foundShortCircuit())
189         {
190             if (!traverser.updateTree(compiler, root))
191             {
192                 return false;
193             }
194         }
195     } while (traverser.foundShortCircuit());
196 
197     return true;
198 }
199 
200 }  // namespace sh
201