//
// Copyright 2019 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// RewriteAtomicCounters: Emulate atomic counter buffers with storage buffers.
//

#include "compiler/translator/tree_ops/RewriteAtomicCounters.h"

#include "compiler/translator/Compiler.h"
#include "compiler/translator/ImmutableStringBuilder.h"
#include "compiler/translator/StaticType.h"
#include "compiler/translator/SymbolTable.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
#include "compiler/translator/tree_util/ReplaceVariable.h"

namespace sh
{
namespace
{
constexpr ImmutableString kAtomicCountersVarName  = ImmutableString("atomicCounters");
constexpr ImmutableString kAtomicCounterFieldName = ImmutableString("counters");

26 // DeclareAtomicCountersBuffer adds a storage buffer array that's used with atomic counters.
DeclareAtomicCountersBuffers(TIntermBlock * root,TSymbolTable * symbolTable)27 const TVariable *DeclareAtomicCountersBuffers(TIntermBlock *root, TSymbolTable *symbolTable)
28 {
29 // Define `uint counters[];` as the only field in the interface block.
30 TFieldList *fieldList = new TFieldList;
31 TType *counterType = new TType(EbtUInt);
32 counterType->makeArray(0);
33
34 TField *countersField =
35 new TField(counterType, kAtomicCounterFieldName, TSourceLoc(), SymbolType::AngleInternal);
36
37 fieldList->push_back(countersField);
38
39 TMemoryQualifier coherentMemory = TMemoryQualifier::Create();
40 coherentMemory.coherent = true;
41
42 // There are a maximum of 8 atomic counter buffers per IMPLEMENTATION_MAX_ATOMIC_COUNTER_BUFFERS
43 // in libANGLE/Constants.h.
44 constexpr uint32_t kMaxAtomicCounterBuffers = 8;
45
46 // Define a storage block "ANGLEAtomicCounters" with instance name "atomicCounters".
47 return DeclareInterfaceBlock(
48 root, symbolTable, fieldList, EvqBuffer, TLayoutQualifier::Create(), coherentMemory,
49 kMaxAtomicCounterBuffers, ImmutableString(vk::kAtomicCountersBlockName),
50 kAtomicCountersVarName);
51 }
52
TIntermTyped *CreateUniformBufferOffset(const TIntermTyped *uniformBufferOffsets, int binding)
{
    // Each uint element of the |acbBufferOffsets| uniform packs byte offsets for 4 bindings, so
    // the offset for a given binding is extracted with:
    //
    //     acbBufferOffsets[binding / 4] >> ((binding % 4) * 8) & 0xFF

    const int element   = binding / 4;
    const int shiftBits = (binding % 4) * 8;

    // acbBufferOffsets[binding / 4]
    TIntermTyped *packedOffsets = new TIntermBinary(
        EOpIndexDirect, uniformBufferOffsets->deepCopy(), CreateIndexNode(element));

    // Shift the desired byte down to the low bits; a shift by 0 is elided.
    if (shiftBits != 0)
    {
        packedOffsets =
            new TIntermBinary(EOpBitShiftRight, packedOffsets, CreateUIntNode(shiftBits));
    }

    // Mask away the other three packed bytes.
    return new TIntermBinary(EOpBitwiseAnd, packedOffsets, CreateUIntNode(0xFF));
}

// Builds the storage-buffer access expression that replaces a use of an atomic counter.
TIntermBinary *CreateAtomicCounterRef(TIntermTyped *atomicCounterExpression,
                                      const TVariable *atomicCounters,
                                      const TIntermTyped *uniformBufferOffsets)
{
    // The atomic counters storage buffer declaration looks as such:
    //
    // layout(...) buffer ANGLEAtomicCounters
    // {
    //     uint counters[];
    // } atomicCounters[N];
    //
    // Where N is large enough to accommodate atomic counter buffer bindings used in the shader.
    //
    // This function takes an expression that uses an atomic counter, which can either be:
    //
    //  - ac
    //  - acArray[index]
    //
    // Note that RewriteArrayOfArrayOfOpaqueUniforms has already flattened array of array of atomic
    // counters.
    //
    // For the first case (ac), the following code is generated:
    //
    //     atomicCounters[binding].counters[offset]
    //
    // For the second case (acArray[index]), the following code is generated:
    //
    //     atomicCounters[binding].counters[offset + index]
    //
    // In either case, an offset given through uniforms is also added to |offset|.  The binding is
    // necessarily a constant thanks to MonomorphizeUnsupportedFunctionsInVulkanGLSL.

    // First determine if there's an index, and extract the atomic counter symbol out of the
    // expression.
    TIntermSymbol *atomicCounterSymbol = atomicCounterExpression->getAsSymbolNode();
    TIntermTyped *atomicCounterIndex   = nullptr;
    int atomicCounterConstIndex        = 0;
    TIntermBinary *asBinary            = atomicCounterExpression->getAsBinaryNode();
    if (asBinary != nullptr)
    {
        atomicCounterSymbol = asBinary->getLeft()->getAsSymbolNode();

        switch (asBinary->getOp())
        {
            case EOpIndexDirect:
                // Constant array index: folded directly into |offset| below.
                atomicCounterConstIndex = asBinary->getRight()->getAsConstantUnion()->getIConst(0);
                break;
            case EOpIndexIndirect:
                // Dynamic array index: added to the index expression at runtime.
                atomicCounterIndex = asBinary->getRight();
                break;
            default:
                UNREACHABLE();
        }
    }

    // Extract binding and offset information out of the atomic counter symbol.
    ASSERT(atomicCounterSymbol);
    const TVariable *atomicCounterVar = &atomicCounterSymbol->variable();
    const TType &atomicCounterType    = atomicCounterVar->getType();

    const int binding = atomicCounterType.getLayoutQualifier().binding;
    // The offset layout qualifier is in bytes; divide by 4 to index the uint-typed buffer.
    int offset = atomicCounterType.getLayoutQualifier().offset / 4;

    // Create the expression:
    //
    //     offset + arrayIndex + uniformOffset
    //
    // If arrayIndex is a constant, it's added with offset right here.

    offset += atomicCounterConstIndex;

    TIntermTyped *index = CreateUniformBufferOffset(uniformBufferOffsets, binding);
    if (atomicCounterIndex != nullptr)
    {
        index = new TIntermBinary(EOpAdd, index, atomicCounterIndex);
    }
    if (offset != 0)
    {
        index = new TIntermBinary(EOpAdd, index, CreateIndexNode(offset));
    }

    // Finally, create the complete expression:
    //
    //     atomicCounters[binding].counters[index]

    TIntermSymbol *atomicCountersRef = new TIntermSymbol(atomicCounters);

    // atomicCounters[binding]
    TIntermBinary *countersBlock =
        new TIntermBinary(EOpIndexDirect, atomicCountersRef, CreateIndexNode(binding));

    // atomicCounters[binding].counters — the interface block has a single field (index 0).
    TIntermBinary *counters =
        new TIntermBinary(EOpIndexDirectInterfaceBlock, countersBlock, CreateIndexNode(0));

    return new TIntermBinary(EOpIndexIndirect, counters, index);
}

// Traverser that:
//
// 1. Removes the |uniform atomic_uint| declarations and remembers the binding and offset.
// 2. Substitutes |atomicVar[n]| with |buffer[binding].counters[offset + n]|.
class RewriteAtomicCountersTraverser : public TIntermTraverser
{
  public:
    RewriteAtomicCountersTraverser(TSymbolTable *symbolTable,
                                   const TVariable *atomicCounters,
                                   const TIntermTyped *acbBufferOffsets)
        : TIntermTraverser(true, false, false, symbolTable),
          mAtomicCounters(atomicCounters),
          mAcbBufferOffsets(acbBufferOffsets)
    {}

    bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
    {
        // Atomic counters can only be declared as uniforms at global scope; nothing to do for
        // local declarations.
        if (!mInGlobalScope)
        {
            return true;
        }

        const TIntermSequence &sequence = *(node->getSequence());

        TIntermTyped *variable = sequence.front()->getAsTyped();
        const TType &type      = variable->getType();
        bool isAtomicCounter   = type.isAtomicCounter();

        if (isAtomicCounter)
        {
            // Drop the |uniform atomic_uint| declaration entirely.  The binding/offset needed at
            // each use site remains available through the variable's type (see
            // CreateAtomicCounterRef).
            ASSERT(type.getQualifier() == EvqUniform);
            TIntermSequence emptySequence;
            mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
                                            std::move(emptySequence));

            return false;
        }

        return true;
    }

    bool visitAggregate(Visit visit, TIntermAggregate *node) override
    {
        if (BuiltInGroup::IsBuiltIn(node->getOp()))
        {
            // If the built-in was converted, stop traversal into the (replaced) call's children.
            bool converted = convertBuiltinFunction(node);
            return !converted;
        }

        // AST functions don't require modification as atomic counter function parameters are
        // removed by MonomorphizeUnsupportedFunctionsInVulkanGLSL.
        return true;
    }

    void visitSymbol(TIntermSymbol *symbol) override
    {
        // Cannot encounter the atomic counter symbol directly.  It can only be used with
        // functions, and therefore it's handled by visitAggregate.
        ASSERT(!symbol->getType().isAtomicCounter());
    }

    bool visitBinary(Visit visit, TIntermBinary *node) override
    {
        // Cannot encounter an atomic counter expression directly.  It can only be used with
        // functions, and therefore it's handled by visitAggregate.
        ASSERT(!node->getType().isAtomicCounter());
        return true;
    }

  private:
    // Replaces atomic-counter built-in calls with storage-buffer atomics.  Returns true if a
    // replacement was queued for |node|.
    bool convertBuiltinFunction(TIntermAggregate *node)
    {
        const TOperator op = node->getOp();

        // If the function is |memoryBarrierAtomicCounter|, simply replace it with
        // |memoryBarrierBuffer|.
        if (op == EOpMemoryBarrierAtomicCounter)
        {
            TIntermSequence emptySequence;
            TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
                "memoryBarrierBuffer", &emptySequence, *mSymbolTable, 310);
            queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
            return true;
        }

        // If it's an |atomicCounter*| function, replace the function with an |atomic*| equivalent.
        if (!node->getFunction()->isAtomicCounterFunction())
        {
            return false;
        }

        // Note: atomicAdd(0) is used for atomic reads.
        uint32_t valueChange                = 0;
        constexpr char kAtomicAddFunction[] = "atomicAdd";
        bool isDecrement                    = false;

        if (op == EOpAtomicCounterIncrement)
        {
            valueChange = 1;
        }
        else if (op == EOpAtomicCounterDecrement)
        {
            // uint values are required to wrap around, so 0xFFFFFFFFu is used as -1.
            valueChange = std::numeric_limits<uint32_t>::max();
            static_assert(static_cast<uint32_t>(-1) == std::numeric_limits<uint32_t>::max(),
                          "uint32_t max is not -1");

            isDecrement = true;
        }
        else
        {
            ASSERT(op == EOpAtomicCounter);
        }

        TIntermTyped *param = (*node->getSequence())[0]->getAsTyped();

        TIntermSequence substituteArguments;
        substituteArguments.push_back(
            CreateAtomicCounterRef(param, mAtomicCounters, mAcbBufferOffsets));
        substituteArguments.push_back(CreateUIntNode(valueChange));

        TIntermTyped *substituteCall = CreateBuiltInFunctionCallNode(
            kAtomicAddFunction, &substituteArguments, *mSymbolTable, 310);

        // Note that atomicCounterDecrement returns the *new* value instead of the prior value,
        // unlike atomicAdd.  So we need to do a -1 on the result as well.
        if (isDecrement)
        {
            substituteCall = new TIntermBinary(EOpSub, substituteCall, CreateUIntNode(1));
        }

        queueReplacement(substituteCall, OriginalNode::IS_DROPPED);
        return true;
    }

    // Storage buffer variable that emulates all atomic counter buffers.
    const TVariable *mAtomicCounters;
    // Uniform holding packed per-binding byte offsets (see CreateUniformBufferOffset).
    const TIntermTyped *mAcbBufferOffsets;
};

}  // anonymous namespace

RewriteAtomicCounters(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const TIntermTyped * acbBufferOffsets)315 bool RewriteAtomicCounters(TCompiler *compiler,
316 TIntermBlock *root,
317 TSymbolTable *symbolTable,
318 const TIntermTyped *acbBufferOffsets)
319 {
320 const TVariable *atomicCounters = DeclareAtomicCountersBuffers(root, symbolTable);
321
322 RewriteAtomicCountersTraverser traverser(symbolTable, atomicCounters, acbBufferOffsets);
323 root->traverse(&traverser);
324 return traverser.updateTree(compiler, root);
325 }
}  // namespace sh