1 //
2 // Copyright 2018 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
// RewriteExpressionsWithShaderStorageBlock rewrites expressions that access shader storage blocks
// into several simpler statements that can be easily handled by the HLSL translator. After this
// AST pass, SSBO-related statements look like the following:
9 //     ssbo_access_chain = ssbo_access_chain;
10 //     ssbo_access_chain = expr_no_ssbo;
11 //     lvalue_no_ssbo    = ssbo_access_chain;
12 //
13 
14 #include "compiler/translator/tree_ops/d3d/RewriteExpressionsWithShaderStorageBlock.h"
15 
16 #include "compiler/translator/Symbol.h"
17 #include "compiler/translator/tree_util/IntermNode_util.h"
18 #include "compiler/translator/tree_util/IntermTraverse.h"
19 #include "compiler/translator/util.h"
20 
21 namespace sh
22 {
23 namespace
24 {
25 
IsIncrementOrDecrementOperator(TOperator op)26 bool IsIncrementOrDecrementOperator(TOperator op)
27 {
28     switch (op)
29     {
30         case EOpPostIncrement:
31         case EOpPostDecrement:
32         case EOpPreIncrement:
33         case EOpPreDecrement:
34             return true;
35         default:
36             return false;
37     }
38 }
39 
IsCompoundAssignment(TOperator op)40 bool IsCompoundAssignment(TOperator op)
41 {
42     switch (op)
43     {
44         case EOpAddAssign:
45         case EOpSubAssign:
46         case EOpMulAssign:
47         case EOpVectorTimesMatrixAssign:
48         case EOpVectorTimesScalarAssign:
49         case EOpMatrixTimesScalarAssign:
50         case EOpMatrixTimesMatrixAssign:
51         case EOpDivAssign:
52         case EOpIModAssign:
53         case EOpBitShiftLeftAssign:
54         case EOpBitShiftRightAssign:
55         case EOpBitwiseAndAssign:
56         case EOpBitwiseXorAssign:
57         case EOpBitwiseOrAssign:
58             return true;
59         default:
60             return false;
61     }
62 }
63 
64 // EOpIndexDirect, EOpIndexIndirect, EOpIndexDirectStruct, EOpIndexDirectInterfaceBlock belong to
65 // operators in SSBO access chain.
IsReadonlyBinaryOperatorNotInSSBOAccessChain(TOperator op)66 bool IsReadonlyBinaryOperatorNotInSSBOAccessChain(TOperator op)
67 {
68     switch (op)
69     {
70         case EOpComma:
71         case EOpAdd:
72         case EOpSub:
73         case EOpMul:
74         case EOpDiv:
75         case EOpIMod:
76         case EOpBitShiftLeft:
77         case EOpBitShiftRight:
78         case EOpBitwiseAnd:
79         case EOpBitwiseXor:
80         case EOpBitwiseOr:
81         case EOpEqual:
82         case EOpNotEqual:
83         case EOpLessThan:
84         case EOpGreaterThan:
85         case EOpLessThanEqual:
86         case EOpGreaterThanEqual:
87         case EOpVectorTimesScalar:
88         case EOpMatrixTimesScalar:
89         case EOpVectorTimesMatrix:
90         case EOpMatrixTimesVector:
91         case EOpMatrixTimesMatrix:
92         case EOpLogicalOr:
93         case EOpLogicalXor:
94         case EOpLogicalAnd:
95             return true;
96         default:
97             return false;
98     }
99 }
100 
HasSSBOAsFunctionArgument(TIntermSequence * arguments)101 bool HasSSBOAsFunctionArgument(TIntermSequence *arguments)
102 {
103     for (TIntermNode *arg : *arguments)
104     {
105         TIntermTyped *typedArg = arg->getAsTyped();
106         if (IsInShaderStorageBlock(typedArg))
107         {
108             return true;
109         }
110     }
111     return false;
112 }
113 
114 class RewriteExpressionsWithShaderStorageBlockTraverser : public TIntermTraverser
115 {
116   public:
117     RewriteExpressionsWithShaderStorageBlockTraverser(TSymbolTable *symbolTable);
118     void nextIteration();
foundSSBO() const119     bool foundSSBO() const { return mFoundSSBO; }
120 
121   private:
122     bool visitBinary(Visit, TIntermBinary *node) override;
123     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
124     bool visitUnary(Visit visit, TIntermUnary *node) override;
125 
126     TIntermSymbol *insertInitStatementAndReturnTempSymbol(TIntermTyped *node,
127                                                           TIntermSequence *insertions);
128 
129     bool mFoundSSBO;
130 };
131 
132 RewriteExpressionsWithShaderStorageBlockTraverser::
RewriteExpressionsWithShaderStorageBlockTraverser(TSymbolTable * symbolTable)133     RewriteExpressionsWithShaderStorageBlockTraverser(TSymbolTable *symbolTable)
134     : TIntermTraverser(true, true, false, symbolTable), mFoundSSBO(false)
135 {}
136 
137 TIntermSymbol *
insertInitStatementAndReturnTempSymbol(TIntermTyped * node,TIntermSequence * insertions)138 RewriteExpressionsWithShaderStorageBlockTraverser::insertInitStatementAndReturnTempSymbol(
139     TIntermTyped *node,
140     TIntermSequence *insertions)
141 {
142     TIntermDeclaration *variableDeclaration;
143     TVariable *tempVariable =
144         DeclareTempVariable(mSymbolTable, node, EvqTemporary, &variableDeclaration);
145 
146     insertions->push_back(variableDeclaration);
147     return CreateTempSymbolNode(tempVariable);
148 }
149 
visitBinary(Visit visit,TIntermBinary * node)150 bool RewriteExpressionsWithShaderStorageBlockTraverser::visitBinary(Visit visit,
151                                                                     TIntermBinary *node)
152 {
153     // Make sure that the expression is caculated from left to right.
154     if (visit != InVisit)
155     {
156         return true;
157     }
158 
159     if (mFoundSSBO)
160     {
161         return false;
162     }
163 
164     bool rightSSBO = IsInShaderStorageBlock(node->getRight());
165     bool leftSSBO  = IsInShaderStorageBlock(node->getLeft());
166     if (!leftSSBO && !rightSSBO)
167     {
168         return true;
169     }
170 
171     // case 1: Compound assigment operator
172     //  original:
173     //      lssbo += expr;
174     //  new:
175     //      var rvalue = expr;
176     //      var temp = lssbo;
177     //      temp += rvalue;
178     //      lssbo = temp;
179     //
180     //  original:
181     //      lvalue_no_ssbo += rssbo;
182     //  new:
183     //      var rvalue = rssbo;
184     //      lvalue_no_ssbo += rvalue;
185     if (IsCompoundAssignment(node->getOp()))
186     {
187         mFoundSSBO = true;
188         TIntermSequence insertions;
189         TIntermTyped *rightNode =
190             insertInitStatementAndReturnTempSymbol(node->getRight(), &insertions);
191         if (leftSSBO)
192         {
193             TIntermSymbol *tempSymbol =
194                 insertInitStatementAndReturnTempSymbol(node->getLeft()->deepCopy(), &insertions);
195             TIntermBinary *tempCompoundOperate =
196                 new TIntermBinary(node->getOp(), tempSymbol->deepCopy(), rightNode->deepCopy());
197             insertions.push_back(tempCompoundOperate);
198             insertStatementsInParentBlock(insertions);
199 
200             TIntermBinary *assignTempValueToSSBO =
201                 new TIntermBinary(EOpAssign, node->getLeft(), tempSymbol->deepCopy());
202             queueReplacement(assignTempValueToSSBO, OriginalNode::IS_DROPPED);
203         }
204         else
205         {
206             insertStatementsInParentBlock(insertions);
207             TIntermBinary *compoundAssignRValueToLValue =
208                 new TIntermBinary(node->getOp(), node->getLeft(), rightNode->deepCopy());
209             queueReplacement(compoundAssignRValueToLValue, OriginalNode::IS_DROPPED);
210         }
211     }
212     // case 2: Readonly binary operator
213     //  original:
214     //      ssbo0 + ssbo1 + ssbo2;
215     //  new:
216     //      var temp0 = ssbo0;
217     //      var temp1 = ssbo1;
218     //      var temp2 = ssbo2;
219     //      temp0 + temp1 + temp2;
220     else if (IsReadonlyBinaryOperatorNotInSSBOAccessChain(node->getOp()) && (leftSSBO || rightSSBO))
221     {
222         mFoundSSBO              = true;
223         TIntermTyped *rightNode = node->getRight();
224         TIntermTyped *leftNode  = node->getLeft();
225         TIntermSequence insertions;
226         if (rightSSBO)
227         {
228             rightNode = insertInitStatementAndReturnTempSymbol(node->getRight(), &insertions);
229         }
230         if (leftSSBO)
231         {
232             leftNode = insertInitStatementAndReturnTempSymbol(node->getLeft(), &insertions);
233         }
234 
235         insertStatementsInParentBlock(insertions);
236         TIntermBinary *newExpr =
237             new TIntermBinary(node->getOp(), leftNode->deepCopy(), rightNode->deepCopy());
238         queueReplacement(newExpr, OriginalNode::IS_DROPPED);
239     }
240     return !mFoundSSBO;
241 }
242 
243 // case 3: ssbo as the argument of aggregate type
244 //  original:
245 //      foo(ssbo);
246 //  new:
247 //      var tempArg = ssbo;
248 //      foo(tempArg);
249 //      ssbo = tempArg;  (Optional based on whether ssbo is an out|input argument)
250 //
251 //  original:
252 //      foo(ssbo) * expr;
253 //  new:
254 //      var tempArg = ssbo;
255 //      var tempReturn = foo(tempArg);
256 //      ssbo = tempArg;  (Optional based on whether ssbo is an out|input argument)
257 //      tempReturn * expr;
visitAggregate(Visit visit,TIntermAggregate * node)258 bool RewriteExpressionsWithShaderStorageBlockTraverser::visitAggregate(Visit visit,
259                                                                        TIntermAggregate *node)
260 {
261     // Make sure that visitAggregate is only executed once for same node.
262     if (visit != PreVisit)
263     {
264         return true;
265     }
266 
267     if (mFoundSSBO)
268     {
269         return false;
270     }
271 
272     // We still need to process the ssbo as the non-first argument of atomic memory functions.
273     if (BuiltInGroup::IsAtomicMemory(node->getOp()) &&
274         IsInShaderStorageBlock((*node->getSequence())[0]->getAsTyped()))
275     {
276         return true;
277     }
278 
279     if (!HasSSBOAsFunctionArgument(node->getSequence()))
280     {
281         return true;
282     }
283 
284     mFoundSSBO = true;
285     TIntermSequence insertions;
286     TIntermSequence readBackToSSBOs;
287     TIntermSequence *originalArguments = node->getSequence();
288     for (size_t i = 0; i < node->getChildCount(); ++i)
289     {
290         TIntermTyped *ssboArgument = (*originalArguments)[i]->getAsTyped();
291         if (IsInShaderStorageBlock(ssboArgument))
292         {
293             TIntermSymbol *argumentCopy =
294                 insertInitStatementAndReturnTempSymbol(ssboArgument, &insertions);
295             if (node->getFunction() != nullptr)
296             {
297                 TQualifier qual = node->getFunction()->getParam(i)->getType().getQualifier();
298                 if (qual == EvqInOut || qual == EvqOut)
299                 {
300                     TIntermBinary *readBackToSSBO = new TIntermBinary(
301                         EOpAssign, ssboArgument->deepCopy(), argumentCopy->deepCopy());
302                     readBackToSSBOs.push_back(readBackToSSBO);
303                 }
304             }
305             node->replaceChildNode(ssboArgument, argumentCopy);
306         }
307     }
308 
309     TIntermBlock *parentBlock = getParentNode()->getAsBlock();
310     if (parentBlock)
311     {
312         // Aggregate node is as a single sentence.
313         insertions.push_back(node);
314         if (!readBackToSSBOs.empty())
315         {
316             insertions.insert(insertions.end(), readBackToSSBOs.begin(), readBackToSSBOs.end());
317         }
318         mMultiReplacements.emplace_back(parentBlock, node, std::move(insertions));
319     }
320     else
321     {
322         // Aggregate node is inside an expression.
323         TIntermSymbol *tempSymbol = insertInitStatementAndReturnTempSymbol(node, &insertions);
324         if (!readBackToSSBOs.empty())
325         {
326             insertions.insert(insertions.end(), readBackToSSBOs.begin(), readBackToSSBOs.end());
327         }
328         insertStatementsInParentBlock(insertions);
329         queueReplacement(tempSymbol->deepCopy(), OriginalNode::IS_DROPPED);
330     }
331 
332     return false;
333 }
334 
visitUnary(Visit visit,TIntermUnary * node)335 bool RewriteExpressionsWithShaderStorageBlockTraverser::visitUnary(Visit visit, TIntermUnary *node)
336 {
337     if (mFoundSSBO)
338     {
339         return false;
340     }
341 
342     if (!IsInShaderStorageBlock(node->getOperand()))
343     {
344         return true;
345     }
346 
347     // .length() is processed in OutputHLSL.
348     if (node->getOp() == EOpArrayLength)
349     {
350         return true;
351     }
352 
353     mFoundSSBO = true;
354 
355     // case 4: ssbo as the operand of ++/--
356     //  original:
357     //      ++ssbo * expr;
358     //  new:
359     //      var temp1 = ssbo;
360     //      var temp2 = ++temp1;
361     //      ssbo = temp1;
362     //      temp2 * expr;
363     if (IsIncrementOrDecrementOperator(node->getOp()))
364     {
365         TIntermSequence insertions;
366         TIntermSymbol *temp1 =
367             insertInitStatementAndReturnTempSymbol(node->getOperand(), &insertions);
368         TIntermUnary *newUnary = new TIntermUnary(node->getOp(), temp1->deepCopy(), nullptr);
369         TIntermSymbol *temp2   = insertInitStatementAndReturnTempSymbol(newUnary, &insertions);
370         TIntermBinary *readBackToSSBO =
371             new TIntermBinary(EOpAssign, node->getOperand()->deepCopy(), temp1->deepCopy());
372         insertions.push_back(readBackToSSBO);
373         insertStatementsInParentBlock(insertions);
374         queueReplacement(temp2->deepCopy(), OriginalNode::IS_DROPPED);
375     }
376     // case 5: ssbo as the operand of readonly unary operator
377     //  original:
378     //      ~ssbo * expr;
379     //  new:
380     //      var temp = ssbo;
381     //      ~temp * expr;
382     else
383     {
384         TIntermSequence insertions;
385         TIntermSymbol *temp =
386             insertInitStatementAndReturnTempSymbol(node->getOperand(), &insertions);
387         insertStatementsInParentBlock(insertions);
388         node->replaceChildNode(node->getOperand(), temp->deepCopy());
389     }
390     return false;
391 }
392 
nextIteration()393 void RewriteExpressionsWithShaderStorageBlockTraverser::nextIteration()
394 {
395     mFoundSSBO = false;
396 }
397 
398 }  // anonymous namespace
399 
RewriteExpressionsWithShaderStorageBlock(TCompiler * compiler,TIntermNode * root,TSymbolTable * symbolTable)400 bool RewriteExpressionsWithShaderStorageBlock(TCompiler *compiler,
401                                               TIntermNode *root,
402                                               TSymbolTable *symbolTable)
403 {
404     RewriteExpressionsWithShaderStorageBlockTraverser traverser(symbolTable);
405     do
406     {
407         traverser.nextIteration();
408         root->traverse(&traverser);
409         if (traverser.foundSSBO())
410         {
411             if (!traverser.updateTree(compiler, root))
412             {
413                 return false;
414             }
415         }
416     } while (traverser.foundSSBO());
417 
418     return true;
419 }
420 }  // namespace sh
421