1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteAtomicCounters: Emulate atomic counter buffers with storage buffers.
7 //
8 
9 #include "compiler/translator/tree_ops/RewriteAtomicCounters.h"
10 
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18 
19 namespace sh
20 {
21 namespace
22 {
23 constexpr ImmutableString kAtomicCountersVarName  = ImmutableString("atomicCounters");
24 constexpr ImmutableString kAtomicCounterFieldName = ImmutableString("counters");
25 
26 // DeclareAtomicCountersBuffer adds a storage buffer array that's used with atomic counters.
DeclareAtomicCountersBuffers(TIntermBlock * root,TSymbolTable * symbolTable)27 const TVariable *DeclareAtomicCountersBuffers(TIntermBlock *root, TSymbolTable *symbolTable)
28 {
29     // Define `uint counters[];` as the only field in the interface block.
30     TFieldList *fieldList = new TFieldList;
31     TType *counterType    = new TType(EbtUInt);
32     counterType->makeArray(0);
33 
34     TField *countersField =
35         new TField(counterType, kAtomicCounterFieldName, TSourceLoc(), SymbolType::AngleInternal);
36 
37     fieldList->push_back(countersField);
38 
39     TMemoryQualifier coherentMemory = TMemoryQualifier::Create();
40     coherentMemory.coherent         = true;
41 
42     // There are a maximum of 8 atomic counter buffers per IMPLEMENTATION_MAX_ATOMIC_COUNTER_BUFFERS
43     // in libANGLE/Constants.h.
44     constexpr uint32_t kMaxAtomicCounterBuffers = 8;
45 
46     // Define a storage block "ANGLEAtomicCounters" with instance name "atomicCounters".
47     return DeclareInterfaceBlock(
48         root, symbolTable, fieldList, EvqBuffer, TLayoutQualifier::Create(), coherentMemory,
49         kMaxAtomicCounterBuffers, ImmutableString(vk::kAtomicCountersBlockName),
50         kAtomicCountersVarName);
51 }
52 
// Builds the expression that extracts the offset belonging to |binding| from the
// |acbBufferOffsets| uniform (passed in as |uniformBufferOffsets|, which is deep-copied rather
// than inserted into the new expression).
//
// Every uint of that uniform holds the offsets of 4 bindings, 8 bits each, so the expression is:
//
//     acbBufferOffsets[binding / 4] >> ((binding % 4) * 8) & 0xFF
//
// The shift is left out entirely when its amount would be zero.
TIntermTyped *CreateUniformBufferOffset(const TIntermTyped *uniformBufferOffsets, int binding)
{
    constexpr int kBindingsPerUint = 4;
    constexpr int kBitsPerBinding  = 8;

    const int packedIndex = binding / kBindingsPerUint;
    const int shiftAmount = (binding % kBindingsPerUint) * kBitsPerBinding;

    // acbBufferOffsets[binding / 4]
    TIntermTyped *extracted = new TIntermBinary(EOpIndexDirect, uniformBufferOffsets->deepCopy(),
                                                CreateIndexNode(packedIndex));

    // acbBufferOffsets[binding / 4] >> ((binding % 4) * 8)
    if (shiftAmount != 0)
    {
        extracted = new TIntermBinary(EOpBitShiftRight, extracted, CreateUIntNode(shiftAmount));
    }

    // acbBufferOffsets[binding / 4] >> ((binding % 4) * 8) & 0xFF
    return new TIntermBinary(EOpBitwiseAnd, extracted, CreateUIntNode(0xFF));
}
75 
CreateAtomicCounterRef(TIntermTyped * atomicCounterExpression,const TVariable * atomicCounters,const TIntermTyped * uniformBufferOffsets)76 TIntermBinary *CreateAtomicCounterRef(TIntermTyped *atomicCounterExpression,
77                                       const TVariable *atomicCounters,
78                                       const TIntermTyped *uniformBufferOffsets)
79 {
80     // The atomic counters storage buffer declaration looks as such:
81     //
82     // layout(...) buffer ANGLEAtomicCounters
83     // {
84     //     uint counters[];
85     // } atomicCounters[N];
86     //
87     // Where N is large enough to accommodate atomic counter buffer bindings used in the shader.
88     //
89     // This function takes an expression that uses an atomic counter, which can either be:
90     //
91     //  - ac
92     //  - acArray[index]
93     //
94     // Note that RewriteArrayOfArrayOfOpaqueUniforms has already flattened array of array of atomic
95     // counters.
96     //
97     // For the first case (ac), the following code is generated:
98     //
99     //     atomicCounters[binding].counters[offset]
100     //
101     // For the second case (acArray[index]), the following code is generated:
102     //
103     //     atomicCounters[binding].counters[offset + index]
104     //
105     // In either case, an offset given through uniforms is also added to |offset|.  The binding is
106     // necessarily a constant thanks to MonomorphizeUnsupportedFunctionsInVulkanGLSL.
107 
108     // First determine if there's an index, and extract the atomic counter symbol out of the
109     // expression.
110     TIntermSymbol *atomicCounterSymbol = atomicCounterExpression->getAsSymbolNode();
111     TIntermTyped *atomicCounterIndex   = nullptr;
112     int atomicCounterConstIndex        = 0;
113     TIntermBinary *asBinary            = atomicCounterExpression->getAsBinaryNode();
114     if (asBinary != nullptr)
115     {
116         atomicCounterSymbol = asBinary->getLeft()->getAsSymbolNode();
117 
118         switch (asBinary->getOp())
119         {
120             case EOpIndexDirect:
121                 atomicCounterConstIndex = asBinary->getRight()->getAsConstantUnion()->getIConst(0);
122                 break;
123             case EOpIndexIndirect:
124                 atomicCounterIndex = asBinary->getRight();
125                 break;
126             default:
127                 UNREACHABLE();
128         }
129     }
130 
131     // Extract binding and offset information out of the atomic counter symbol.
132     ASSERT(atomicCounterSymbol);
133     const TVariable *atomicCounterVar = &atomicCounterSymbol->variable();
134     const TType &atomicCounterType    = atomicCounterVar->getType();
135 
136     const int binding = atomicCounterType.getLayoutQualifier().binding;
137     int offset        = atomicCounterType.getLayoutQualifier().offset / 4;
138 
139     // Create the expression:
140     //
141     //     offset + arrayIndex + uniformOffset
142     //
143     // If arrayIndex is a constant, it's added with offset right here.
144 
145     offset += atomicCounterConstIndex;
146 
147     TIntermTyped *index = CreateUniformBufferOffset(uniformBufferOffsets, binding);
148     if (atomicCounterIndex != nullptr)
149     {
150         index = new TIntermBinary(EOpAdd, index, atomicCounterIndex);
151     }
152     if (offset != 0)
153     {
154         index = new TIntermBinary(EOpAdd, index, CreateIndexNode(offset));
155     }
156 
157     // Finally, create the complete expression:
158     //
159     //     atomicCounters[binding].counters[index]
160 
161     TIntermSymbol *atomicCountersRef = new TIntermSymbol(atomicCounters);
162 
163     // atomicCounters[binding]
164     TIntermBinary *countersBlock =
165         new TIntermBinary(EOpIndexDirect, atomicCountersRef, CreateIndexNode(binding));
166 
167     // atomicCounters[binding].counters
168     TIntermBinary *counters =
169         new TIntermBinary(EOpIndexDirectInterfaceBlock, countersBlock, CreateIndexNode(0));
170 
171     return new TIntermBinary(EOpIndexIndirect, counters, index);
172 }
173 
174 // Traverser that:
175 //
176 // 1. Removes the |uniform atomic_uint| declarations and remembers the binding and offset.
177 // 2. Substitutes |atomicVar[n]| with |buffer[binding].counters[offset + n]|.
178 class RewriteAtomicCountersTraverser : public TIntermTraverser
179 {
180   public:
RewriteAtomicCountersTraverser(TSymbolTable * symbolTable,const TVariable * atomicCounters,const TIntermTyped * acbBufferOffsets)181     RewriteAtomicCountersTraverser(TSymbolTable *symbolTable,
182                                    const TVariable *atomicCounters,
183                                    const TIntermTyped *acbBufferOffsets)
184         : TIntermTraverser(true, false, false, symbolTable),
185           mAtomicCounters(atomicCounters),
186           mAcbBufferOffsets(acbBufferOffsets)
187     {}
188 
visitDeclaration(Visit visit,TIntermDeclaration * node)189     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
190     {
191         if (!mInGlobalScope)
192         {
193             return true;
194         }
195 
196         const TIntermSequence &sequence = *(node->getSequence());
197 
198         TIntermTyped *variable = sequence.front()->getAsTyped();
199         const TType &type      = variable->getType();
200         bool isAtomicCounter   = type.isAtomicCounter();
201 
202         if (isAtomicCounter)
203         {
204             ASSERT(type.getQualifier() == EvqUniform);
205             TIntermSequence emptySequence;
206             mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
207                                             std::move(emptySequence));
208 
209             return false;
210         }
211 
212         return true;
213     }
214 
visitAggregate(Visit visit,TIntermAggregate * node)215     bool visitAggregate(Visit visit, TIntermAggregate *node) override
216     {
217         if (BuiltInGroup::IsBuiltIn(node->getOp()))
218         {
219             bool converted = convertBuiltinFunction(node);
220             return !converted;
221         }
222 
223         // AST functions don't require modification as atomic counter function parameters are
224         // removed by MonomorphizeUnsupportedFunctionsInVulkanGLSL.
225         return true;
226     }
227 
visitSymbol(TIntermSymbol * symbol)228     void visitSymbol(TIntermSymbol *symbol) override
229     {
230         // Cannot encounter the atomic counter symbol directly.  It can only be used with functions,
231         // and therefore it's handled by visitAggregate.
232         ASSERT(!symbol->getType().isAtomicCounter());
233     }
234 
visitBinary(Visit visit,TIntermBinary * node)235     bool visitBinary(Visit visit, TIntermBinary *node) override
236     {
237         // Cannot encounter an atomic counter expression directly.  It can only be used with
238         // functions, and therefore it's handled by visitAggregate.
239         ASSERT(!node->getType().isAtomicCounter());
240         return true;
241     }
242 
243   private:
convertBuiltinFunction(TIntermAggregate * node)244     bool convertBuiltinFunction(TIntermAggregate *node)
245     {
246         const TOperator op = node->getOp();
247 
248         // If the function is |memoryBarrierAtomicCounter|, simply replace it with
249         // |memoryBarrierBuffer|.
250         if (op == EOpMemoryBarrierAtomicCounter)
251         {
252             TIntermSequence emptySequence;
253             TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
254                 "memoryBarrierBuffer", &emptySequence, *mSymbolTable, 310);
255             queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
256             return true;
257         }
258 
259         // If it's an |atomicCounter*| function, replace the function with an |atomic*| equivalent.
260         if (!node->getFunction()->isAtomicCounterFunction())
261         {
262             return false;
263         }
264 
265         // Note: atomicAdd(0) is used for atomic reads.
266         uint32_t valueChange                = 0;
267         constexpr char kAtomicAddFunction[] = "atomicAdd";
268         bool isDecrement                    = false;
269 
270         if (op == EOpAtomicCounterIncrement)
271         {
272             valueChange = 1;
273         }
274         else if (op == EOpAtomicCounterDecrement)
275         {
276             // uint values are required to wrap around, so 0xFFFFFFFFu is used as -1.
277             valueChange = std::numeric_limits<uint32_t>::max();
278             static_assert(static_cast<uint32_t>(-1) == std::numeric_limits<uint32_t>::max(),
279                           "uint32_t max is not -1");
280 
281             isDecrement = true;
282         }
283         else
284         {
285             ASSERT(op == EOpAtomicCounter);
286         }
287 
288         TIntermTyped *param = (*node->getSequence())[0]->getAsTyped();
289 
290         TIntermSequence substituteArguments;
291         substituteArguments.push_back(
292             CreateAtomicCounterRef(param, mAtomicCounters, mAcbBufferOffsets));
293         substituteArguments.push_back(CreateUIntNode(valueChange));
294 
295         TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
296             kAtomicAddFunction, &substituteArguments, *mSymbolTable, 310);
297 
298         // Note that atomicCounterDecrement returns the *new* value instead of the prior value,
299         // unlike atomicAdd.  So we need to do a -1 on the result as well.
300         if (isDecrement)
301         {
302             substituteCall = new TIntermBinary(EOpSub, substituteCall, CreateUIntNode(1));
303         }
304 
305         queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
306         return true;
307     }
308 
309     const TVariable *mAtomicCounters;
310     const TIntermTyped *mAcbBufferOffsets;
311 };
312 
313 }  // anonymous namespace
314 
RewriteAtomicCounters(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const TIntermTyped * acbBufferOffsets)315 bool RewriteAtomicCounters(TCompiler *compiler,
316                            TIntermBlock *root,
317                            TSymbolTable *symbolTable,
318                            const TIntermTyped *acbBufferOffsets)
319 {
320     const TVariable *atomicCounters = DeclareAtomicCountersBuffers(root, symbolTable);
321 
322     RewriteAtomicCountersTraverser traverser(symbolTable, atomicCounters, acbBufferOffsets);
323     root->traverse(&traverser);
324     return traverser.updateTree(compiler, root);
325 }
326 }  // namespace sh
327