1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
// RewriteArrayOfArrayOfOpaqueUniforms: Flatten arrays of arrays of opaque uniforms (samplers,
// images and atomic counters) into single-dimensional arrays.
7 //
8 
9 #include "compiler/translator/tree_ops/vulkan/RewriteArrayOfArrayOfOpaqueUniforms.h"
10 
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18 
19 namespace sh
20 {
21 namespace
22 {
23 struct UniformData
24 {
25     // Corresponding to an array of array of opaque uniform variable, this is the flattened variable
26     // that is replacing it.
27     const TVariable *flattened;
28     // Assume a general case of array declaration with N dimensions:
29     //
30     //     uniform type u[Dn]..[D2][D1];
31     //
32     // Let's define
33     //
34     //     Pn = D(n-1)*...*D2*D1
35     //
36     // In that case, we have:
37     //
38     //     u[In]         = ac + In*Pn
39     //     u[In][I(n-1)] = ac + In*Pn + I(n-1)*P(n-1)
40     //     u[In]...[Ii]  = ac + In*Pn + ... + Ii*Pi
41     //
42     // This array contains Pi.  Note that the like TType::mArraySizes, the last element is the
43     // outermost dimension.  Element 0 is necessarily 1.
44     TVector<unsigned int> mSubArraySizes;
45 };
46 
47 using UniformMap = angle::HashMap<const TVariable *, UniformData>;
48 
49 TIntermTyped *RewriteArrayOfArraySubscriptExpression(TCompiler *compiler,
50                                                      TIntermBinary *node,
51                                                      const UniformMap &uniformMap);
52 
53 // Given an expression, this traverser calculates a new expression where array of array of opaque
54 // uniforms are replaced with their flattened ones.  In particular, this is run on the right node of
55 // EOpIndexIndirect binary nodes, so that the expression in the index gets a chance to go through
56 // this transformation.
57 class RewriteExpressionTraverser final : public TIntermTraverser
58 {
59   public:
RewriteExpressionTraverser(TCompiler * compiler,const UniformMap & uniformMap)60     explicit RewriteExpressionTraverser(TCompiler *compiler, const UniformMap &uniformMap)
61         : TIntermTraverser(true, false, false), mCompiler(compiler), mUniformMap(uniformMap)
62     {}
63 
visitBinary(Visit visit,TIntermBinary * node)64     bool visitBinary(Visit visit, TIntermBinary *node) override
65     {
66         TIntermTyped *rewritten =
67             RewriteArrayOfArraySubscriptExpression(mCompiler, node, mUniformMap);
68         if (rewritten == nullptr)
69         {
70             return true;
71         }
72 
73         queueReplacement(rewritten, OriginalNode::IS_DROPPED);
74 
75         // Don't iterate as the expression is rewritten.
76         return false;
77     }
78 
visitSymbol(TIntermSymbol * node)79     void visitSymbol(TIntermSymbol *node) override
80     {
81         // We cannot reach here for an opaque uniform that is being replaced.  visitBinary should
82         // have taken care of it.
83         ASSERT(!IsOpaqueType(node->getType().getBasicType()) ||
84                mUniformMap.find(&node->variable()) == mUniformMap.end());
85     }
86 
87   private:
88     TCompiler *mCompiler;
89 
90     const UniformMap &mUniformMap;
91 };
92 
93 // Rewrite the index of an EOpIndexIndirect expression.  The root can never need replacing, because
94 // it cannot be an opaque uniform itself.
RewriteIndexExpression(TCompiler * compiler,TIntermTyped * expression,const UniformMap & uniformMap)95 void RewriteIndexExpression(TCompiler *compiler,
96                             TIntermTyped *expression,
97                             const UniformMap &uniformMap)
98 {
99     RewriteExpressionTraverser traverser(compiler, uniformMap);
100     expression->traverse(&traverser);
101     bool valid = traverser.updateTree(compiler, expression);
102     ASSERT(valid);
103 }
104 
105 // Given an expression such as the following:
106 //
107 //                                              EOpIndex(In)Direct (opaque uniform)
108 //                                                    /           \
109 //                                            EOpIndex(In)Direct   I1
110 //                                                  /           \
111 //                                                ...            I2
112 //                                            /
113 //                                    EOpIndex(In)Direct
114 //                                          /           \
115 //                                      uniform          In
116 //
117 // produces:
118 //
119 //          EOpIndex(In)Direct
120 //            /        \
121 //        uniform    In*Pn + ... + I2*P2 + I1*P1
122 //
RewriteArrayOfArraySubscriptExpression(TCompiler * compiler,TIntermBinary * node,const UniformMap & uniformMap)123 TIntermTyped *RewriteArrayOfArraySubscriptExpression(TCompiler *compiler,
124                                                      TIntermBinary *node,
125                                                      const UniformMap &uniformMap)
126 {
127     // Only interested in opaque uniforms.
128     if (!IsOpaqueType(node->getType().getBasicType()))
129     {
130         return nullptr;
131     }
132 
133     TIntermSymbol *opaqueUniform = nullptr;
134 
135     // Iterate once and find the opaque uniform that's being indexed.
136     TIntermBinary *iter = node;
137     while (opaqueUniform == nullptr)
138     {
139         ASSERT(iter->getOp() == EOpIndexDirect || iter->getOp() == EOpIndexIndirect);
140 
141         opaqueUniform = iter->getLeft()->getAsSymbolNode();
142         iter          = iter->getLeft()->getAsBinaryNode();
143     }
144 
145     // If not being replaced, there's nothing to do.
146     auto flattenedIter = uniformMap.find(&opaqueUniform->variable());
147     if (flattenedIter == uniformMap.end())
148     {
149         return nullptr;
150     }
151 
152     const UniformData &data = flattenedIter->second;
153 
154     // Iterate again and build the index expression.  The index expression constitutes the sum of
155     // the variable indices plus a constant offset calculated from the constant indices.  For
156     // example, smplr[1][x][2][y] will have an index of x*P3 + y*P1 + c, where c = (1*P4 + 2*P2).
157     unsigned int constantOffset = 0;
158     TIntermTyped *variableIndex = nullptr;
159 
160     // Since the opaque uniforms are fully subscripted, we know exactly how many EOpIndex* nodes
161     // there should be.
162     for (size_t dimIndex = 0; dimIndex < data.mSubArraySizes.size(); ++dimIndex)
163     {
164         ASSERT(node);
165 
166         unsigned int subArraySize = data.mSubArraySizes[dimIndex];
167 
168         switch (node->getOp())
169         {
170             case EOpIndexDirect:
171                 // Accumulate the constant index.
172                 constantOffset +=
173                     node->getRight()->getAsConstantUnion()->getIConst(0) * subArraySize;
174                 break;
175             case EOpIndexIndirect:
176             {
177                 // Run RewriteExpressionTraverser on the right node.  It may itself be an expression
178                 // with an array of array of opaque uniform inside that needs to be rewritten.
179                 TIntermTyped *indexExpression = node->getRight();
180                 RewriteIndexExpression(compiler, indexExpression, uniformMap);
181 
182                 // Scale and accumulate.
183                 if (subArraySize != 1)
184                 {
185                     indexExpression =
186                         new TIntermBinary(EOpMul, indexExpression, CreateIndexNode(subArraySize));
187                 }
188 
189                 if (variableIndex == nullptr)
190                 {
191                     variableIndex = indexExpression;
192                 }
193                 else
194                 {
195                     variableIndex = new TIntermBinary(EOpAdd, variableIndex, indexExpression);
196                 }
197                 break;
198             }
199             default:
200                 UNREACHABLE();
201                 break;
202         }
203 
204         node = node->getLeft()->getAsBinaryNode();
205     }
206 
207     // Add the two accumulated indices together.
208     TIntermTyped *index = nullptr;
209     if (constantOffset == 0 && variableIndex != nullptr)
210     {
211         // No constant offset, but there's variable offset.  Take that as offset.
212         index = variableIndex;
213     }
214     else
215     {
216         // Either the constant offset is non zero, or there's no variable offset (so constant 0
217         // should be used).
218         index = CreateIndexNode(constantOffset);
219 
220         if (variableIndex)
221         {
222             index = new TIntermBinary(EOpAdd, index, variableIndex);
223         }
224     }
225 
226     // Create an index into the flattened uniform.
227     TOperator op = variableIndex ? EOpIndexIndirect : EOpIndexDirect;
228     return new TIntermBinary(op, new TIntermSymbol(data.flattened), index);
229 }
230 
231 // Traverser that takes:
232 //
233 //     uniform sampler/image/atomic_uint u[N][M]..
234 //
235 // and transforms it to:
236 //
237 //     uniform sampler/image/atomic_uint u[N * M * ..]
238 //
239 // MonomorphizeUnsupportedFunctionsInVulkanGLSL makes it impossible for this array to be partially
240 // subscripted, or passed as argument to a function unsubscripted.  This means that every encounter
241 // of this uniform can be expected to be fully subscripted.
242 //
243 class RewriteArrayOfArrayOfOpaqueUniformsTraverser : public TIntermTraverser
244 {
245   public:
RewriteArrayOfArrayOfOpaqueUniformsTraverser(TCompiler * compiler,TSymbolTable * symbolTable)246     RewriteArrayOfArrayOfOpaqueUniformsTraverser(TCompiler *compiler, TSymbolTable *symbolTable)
247         : TIntermTraverser(true, false, false, symbolTable), mCompiler(compiler)
248     {}
249 
visitDeclaration(Visit visit,TIntermDeclaration * node)250     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
251     {
252         if (!mInGlobalScope)
253         {
254             return true;
255         }
256 
257         const TIntermSequence &sequence = *(node->getSequence());
258 
259         TIntermTyped *variable = sequence.front()->getAsTyped();
260         const TType &type      = variable->getType();
261         bool isOpaqueUniform =
262             type.getQualifier() == EvqUniform && IsOpaqueType(type.getBasicType());
263 
264         // Only interested in array of array of opaque uniforms.
265         if (!isOpaqueUniform || !type.isArrayOfArrays())
266         {
267             return false;
268         }
269 
270         // Opaque uniforms cannot have initializers, so the declaration must necessarily be a
271         // symbol.
272         TIntermSymbol *symbol = variable->getAsSymbolNode();
273         ASSERT(symbol != nullptr);
274 
275         const TVariable *uniformVariable = &symbol->variable();
276 
277         // Create an entry in the map.
278         ASSERT(mUniformMap.find(uniformVariable) == mUniformMap.end());
279         UniformData &data = mUniformMap[uniformVariable];
280 
281         // Calculate the accumulated dimension products.  See UniformData::mSubArraySizes.
282         const TSpan<const unsigned int> &arraySizes = type.getArraySizes();
283         mUniformMap[uniformVariable].mSubArraySizes.resize(arraySizes.size());
284         unsigned int runningProduct = 1;
285         for (size_t dimension = 0; dimension < arraySizes.size(); ++dimension)
286         {
287             data.mSubArraySizes[dimension] = runningProduct;
288             runningProduct *= arraySizes[dimension];
289         }
290 
291         // Create a replacement variable with the array flattened.
292         TType *newType = new TType(type);
293         newType->toArrayBaseType();
294         newType->makeArray(runningProduct);
295 
296         data.flattened = new TVariable(mSymbolTable, uniformVariable->name(), newType,
297                                        uniformVariable->symbolType());
298 
299         TIntermDeclaration *decl = new TIntermDeclaration;
300         decl->appendDeclarator(new TIntermSymbol(data.flattened));
301 
302         queueReplacement(decl, OriginalNode::IS_DROPPED);
303         return false;
304     }
305 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)306     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
307     {
308         // As an optimization, don't bother inspecting functions if there aren't any opaque uniforms
309         // to replace.
310         return !mUniformMap.empty();
311     }
312 
313     // Same implementation as in RewriteExpressionTraverser.  That traverser cannot replace root.
visitBinary(Visit visit,TIntermBinary * node)314     bool visitBinary(Visit visit, TIntermBinary *node) override
315     {
316         TIntermTyped *rewritten =
317             RewriteArrayOfArraySubscriptExpression(mCompiler, node, mUniformMap);
318         if (rewritten == nullptr)
319         {
320             return true;
321         }
322 
323         queueReplacement(rewritten, OriginalNode::IS_DROPPED);
324 
325         // Don't iterate as the expression is rewritten.
326         return false;
327     }
328 
visitSymbol(TIntermSymbol * node)329     void visitSymbol(TIntermSymbol *node) override
330     {
331         ASSERT(!IsOpaqueType(node->getType().getBasicType()) ||
332                mUniformMap.find(&node->variable()) == mUniformMap.end());
333     }
334 
335   private:
336     TCompiler *mCompiler;
337     UniformMap mUniformMap;
338 };
339 }  // anonymous namespace
340 
RewriteArrayOfArrayOfOpaqueUniforms(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)341 bool RewriteArrayOfArrayOfOpaqueUniforms(TCompiler *compiler,
342                                          TIntermBlock *root,
343                                          TSymbolTable *symbolTable)
344 {
345     RewriteArrayOfArrayOfOpaqueUniformsTraverser traverser(compiler, symbolTable);
346     root->traverse(&traverser);
347     return traverser.updateTree(compiler, root);
348 }
349 }  // namespace sh
350