1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // RewriteR32fImages: Change images qualified with r32f to use r32ui instead.
7 //
8 
9 #include "compiler/translator/tree_ops/vulkan/RewriteR32fImages.h"
10 
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18 
19 namespace sh
20 {
21 namespace
22 {
IsR32fImage(const TType & type)23 bool IsR32fImage(const TType &type)
24 {
25     return type.getQualifier() == EvqUniform && type.isImage() &&
26            type.getLayoutQualifier().imageInternalFormat == EiifR32F;
27 }
28 
// Maps each original r32f image uniform variable to the r32ui variable that replaces it.
using ImageMap = angle::HashMap<const TVariable *, const TVariable *>;

// Rewrites a call to an image builtin whose image argument is a mapped r32f image.  Returns the
// replacement expression, or nullptr if |node| needs no rewriting.  The definition appears further
// below; it is declared here because RewriteExpressionTraverser (defined first) calls it.
TIntermTyped *RewriteBuiltinFunctionCall(TCompiler *compiler,
                                         TSymbolTable *symbolTable,
                                         TIntermAggregate *node,
                                         const ImageMap &imageMap);
35 
36 // Given an expression, this traverser calculates a new expression where builtin function calls to
37 // r32f images are replaced with ones to the mapped r32ui image.  In particular, this is run on the
38 // right node of EOpIndexIndirect binary nodes, so that the expression in the index gets a chance to
39 // go through this transformation.
40 class RewriteExpressionTraverser final : public TIntermTraverser
41 {
42   public:
RewriteExpressionTraverser(TCompiler * compiler,TSymbolTable * symbolTable,const ImageMap & imageMap)43     explicit RewriteExpressionTraverser(TCompiler *compiler,
44                                         TSymbolTable *symbolTable,
45                                         const ImageMap &imageMap)
46         : TIntermTraverser(true, false, false, symbolTable),
47           mCompiler(compiler),
48           mImageMap(imageMap)
49     {}
50 
visitAggregate(Visit visit,TIntermAggregate * node)51     bool visitAggregate(Visit visit, TIntermAggregate *node) override
52     {
53         TIntermTyped *rewritten =
54             RewriteBuiltinFunctionCall(mCompiler, mSymbolTable, node, mImageMap);
55         if (rewritten == nullptr)
56         {
57             return true;
58         }
59 
60         queueReplacement(rewritten, OriginalNode::IS_DROPPED);
61 
62         // Don't iterate as the expression is rewritten.
63         return false;
64     }
65 
66   private:
67     TCompiler *mCompiler;
68 
69     const ImageMap &mImageMap;
70 };
71 
72 // Rewrite the index of an EOpIndexIndirect expression as well as any arguments to the builtin
73 // function call.
RewriteExpression(TCompiler * compiler,TSymbolTable * symbolTable,TIntermTyped * expression,const ImageMap & imageMap)74 TIntermTyped *RewriteExpression(TCompiler *compiler,
75                                 TSymbolTable *symbolTable,
76                                 TIntermTyped *expression,
77                                 const ImageMap &imageMap)
78 {
79     // Create a fake block to insert the node in.  The root itself may need changing.
80     TIntermBlock block;
81     block.appendStatement(expression);
82 
83     RewriteExpressionTraverser traverser(compiler, symbolTable, imageMap);
84     block.traverse(&traverser);
85 
86     bool valid = traverser.updateTree(compiler, &block);
87     ASSERT(valid);
88 
89     TIntermTyped *rewritten = block.getChildNode(0)->getAsTyped();
90 
91     return rewritten;
92 }
93 
94 // Given a builtin function call such as the following:
95 //
96 //     imageLoad(expression, ...);
97 //
98 // expression is in the form of:
99 //
100 // - image uniform
101 // - image uniform array indexed with EOpIndexDirect or EOpIndexIndirect.  Note that
102 //   RewriteArrayOfArrayOfOpaqueUniforms has already ensured that the image array is
103 //   single-dimension.
104 //
105 // The latter case (with EOpIndexIndirect) is not valid GLSL (up to GL_EXT_gpu_shader5), but if it
106 // were, the index itself could have contained an image builtin function call, so is recursively
107 // processed (in case supported in future).  Additionally, the other builtin function arguments may
108 // need processing too.
109 //
110 // This function creates a similar expression where the image uniforms (of type r32f) are replaced
111 // with those of r32ui type.
112 //
RewriteBuiltinFunctionCall(TCompiler * compiler,TSymbolTable * symbolTable,TIntermAggregate * node,const ImageMap & imageMap)113 TIntermTyped *RewriteBuiltinFunctionCall(TCompiler *compiler,
114                                          TSymbolTable *symbolTable,
115                                          TIntermAggregate *node,
116                                          const ImageMap &imageMap)
117 {
118     if (!BuiltInGroup::IsBuiltIn(node->getOp()))
119     {
120         // AST functions don't require modification as r32f image function parameters are removed by
121         // MonomorphizeUnsupportedFunctionsInVulkanGLSL.
122         return nullptr;
123     }
124 
125     // If it's an |image*| function, replace the function with an equivalent that uses an r32ui
126     // image.
127     if (!node->getFunction()->isImageFunction())
128     {
129         return nullptr;
130     }
131 
132     TIntermSequence *arguments = node->getSequence();
133 
134     TIntermTyped *imageExpression = (*arguments)[0]->getAsTyped();
135     ASSERT(imageExpression);
136 
137     // Find the image uniform that's being indexed, if indexed.
138     TIntermBinary *asBinary     = imageExpression->getAsBinaryNode();
139     TIntermSymbol *imageUniform = imageExpression->getAsSymbolNode();
140 
141     if (asBinary)
142     {
143         ASSERT(asBinary->getOp() == EOpIndexDirect || asBinary->getOp() == EOpIndexIndirect);
144         imageUniform = asBinary->getLeft()->getAsSymbolNode();
145     }
146 
147     ASSERT(imageUniform);
148     if (!IsR32fImage(imageUniform->getType()))
149     {
150         return nullptr;
151     }
152 
153     ASSERT(imageMap.find(&imageUniform->variable()) != imageMap.end());
154     const TVariable *replacementImage = imageMap.at(&imageUniform->variable());
155 
156     // Build the expression again, with the image uniform replaced.  If index is dynamic,
157     // recursively process it.
158     TIntermTyped *replacementExpression = new TIntermSymbol(replacementImage);
159 
160     // Index it, if indexed.
161     if (asBinary != nullptr)
162     {
163         TIntermTyped *index = asBinary->getRight();
164 
165         switch (asBinary->getOp())
166         {
167             case EOpIndexDirect:
168                 break;
169             case EOpIndexIndirect:
170             {
171                 // Run RewriteExpressionTraverser on the index node.  This case is currently
172                 // impossible with known extensions.
173                 UNREACHABLE();
174                 index = RewriteExpression(compiler, symbolTable, index, imageMap);
175                 break;
176             }
177             default:
178                 UNREACHABLE();
179                 break;
180         }
181 
182         replacementExpression = new TIntermBinary(asBinary->getOp(), replacementExpression, index);
183     }
184 
185     TIntermSequence substituteArguments;
186     substituteArguments.push_back(replacementExpression);
187 
188     for (size_t argIndex = 1; argIndex < arguments->size(); ++argIndex)
189     {
190         TIntermTyped *arg = (*arguments)[argIndex]->getAsTyped();
191 
192         // Run RewriteExpressionTraverser on the argument.  It may itself be an expression with an
193         // r32f image that needs to be rewritten.
194         arg = RewriteExpression(compiler, symbolTable, arg, imageMap);
195         substituteArguments.push_back(arg);
196     }
197 
198     const ImmutableString &functionName = node->getFunction()->name();
199     bool isImageAtomicExchange          = functionName == "imageAtomicExchange";
200     bool isImageLoad                    = false;
201 
202     if (functionName == "imageStore" || isImageAtomicExchange)
203     {
204         // The last parameter is float data, which should be changed to floatBitsToUint(data).
205         TIntermTyped *data = substituteArguments.back()->getAsTyped();
206         substituteArguments.back() =
207             CreateBuiltInUnaryFunctionCallNode("floatBitsToUint", data, *symbolTable, 300);
208     }
209     else if (functionName == "imageLoad")
210     {
211         isImageLoad = true;
212     }
213     else
214     {
215         // imageSize does not have any other arguments.
216         ASSERT(functionName == "imageSize");
217         ASSERT(arguments->size() == 1);
218     }
219 
220     TIntermTyped *replacementCall =
221         CreateBuiltInFunctionCallNode(functionName.data(), &substituteArguments, *symbolTable, 310);
222 
223     // If imageLoad or imageAtomicExchange, the result is now uint, which should be converted with
224     // uintBitsToFloat.  With imageLoad, the alpha channel should always read 1.0 regardless.
225     if (isImageLoad || isImageAtomicExchange)
226     {
227         if (isImageLoad)
228         {
229             // imageLoad().rgb
230             replacementCall = new TIntermSwizzle(replacementCall, {0, 1, 2});
231         }
232 
233         // uintBitsToFloat(imageLoad().rgb), or uintBitsToFloat(imageAtomicExchange())
234         replacementCall = CreateBuiltInUnaryFunctionCallNode("uintBitsToFloat", replacementCall,
235                                                              *symbolTable, 300);
236 
237         if (isImageLoad)
238         {
239             // vec4(uintBitsToFloat(imageLoad().rgb), 1.0)
240             const TType &vec4Type           = *StaticType::GetBasic<EbtFloat, 4>();
241             TIntermSequence constructorArgs = {replacementCall, CreateFloatNode(1.0f)};
242             replacementCall = TIntermAggregate::CreateConstructor(vec4Type, &constructorArgs);
243         }
244     }
245 
246     return replacementCall;
247 }
248 
249 // Traverser that:
250 //
251 // 1. Converts the layout(r32f, ...) ... image* name; declarations to use the r32ui format
252 // 2. Converts |imageLoad| and |imageStore| functions to use |uintBitsToFloat| and |floatBitsToUint|
253 //    respectively.
254 // 3. Converts |imageAtomicExchange| to use |floatBitsToUint| and |uintBitsToFloat|.
255 class RewriteR32fImagesTraverser : public TIntermTraverser
256 {
257   public:
RewriteR32fImagesTraverser(TCompiler * compiler,TSymbolTable * symbolTable)258     RewriteR32fImagesTraverser(TCompiler *compiler, TSymbolTable *symbolTable)
259         : TIntermTraverser(true, false, false, symbolTable), mCompiler(compiler)
260     {}
261 
visitDeclaration(Visit visit,TIntermDeclaration * node)262     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
263     {
264         if (visit != PreVisit)
265         {
266             return true;
267         }
268 
269         const TIntermSequence &sequence = *(node->getSequence());
270 
271         TIntermTyped *declVariable = sequence.front()->getAsTyped();
272         const TType &type          = declVariable->getType();
273 
274         if (!IsR32fImage(type))
275         {
276             return true;
277         }
278 
279         TIntermSymbol *oldSymbol = declVariable->getAsSymbolNode();
280         ASSERT(oldSymbol != nullptr);
281 
282         const TVariable &oldVariable = oldSymbol->variable();
283 
284         TType *newType                      = new TType(type);
285         TLayoutQualifier layoutQualifier    = type.getLayoutQualifier();
286         layoutQualifier.imageInternalFormat = EiifR32UI;
287         newType->setLayoutQualifier(layoutQualifier);
288 
289         switch (type.getBasicType())
290         {
291             case EbtImage2D:
292                 newType->setBasicType(EbtUImage2D);
293                 break;
294             case EbtImage3D:
295                 newType->setBasicType(EbtUImage3D);
296                 break;
297             case EbtImage2DArray:
298                 newType->setBasicType(EbtUImage2DArray);
299                 break;
300             case EbtImageCube:
301                 newType->setBasicType(EbtUImageCube);
302                 break;
303             case EbtImage1D:
304                 newType->setBasicType(EbtUImage1D);
305                 break;
306             case EbtImage1DArray:
307                 newType->setBasicType(EbtUImage1DArray);
308                 break;
309             case EbtImage2DMS:
310                 newType->setBasicType(EbtUImage2DMS);
311                 break;
312             case EbtImage2DMSArray:
313                 newType->setBasicType(EbtUImage2DMSArray);
314                 break;
315             case EbtImageCubeArray:
316                 newType->setBasicType(EbtUImageCubeArray);
317                 break;
318             case EbtImageRect:
319                 newType->setBasicType(EbtUImageRect);
320                 break;
321             case EbtImageBuffer:
322                 newType->setBasicType(EbtUImageBuffer);
323                 break;
324             default:
325                 UNREACHABLE();
326         }
327 
328         TVariable *newVariable =
329             new TVariable(oldVariable.uniqueId(), oldVariable.name(), oldVariable.symbolType(),
330                           oldVariable.extensions(), newType);
331 
332         mImageMap[&oldVariable] = newVariable;
333 
334         TIntermDeclaration *newDecl = new TIntermDeclaration();
335         newDecl->appendDeclarator(new TIntermSymbol(newVariable));
336 
337         queueReplacement(newDecl, OriginalNode::IS_DROPPED);
338 
339         return false;
340     }
341 
342     // Same implementation as in RewriteExpressionTraverser.  That traverser cannot replace root.
visitAggregate(Visit visit,TIntermAggregate * node)343     bool visitAggregate(Visit visit, TIntermAggregate *node) override
344     {
345         TIntermTyped *rewritten =
346             RewriteBuiltinFunctionCall(mCompiler, mSymbolTable, node, mImageMap);
347         if (rewritten == nullptr)
348         {
349             return true;
350         }
351 
352         queueReplacement(rewritten, OriginalNode::IS_DROPPED);
353 
354         return false;
355     }
356 
visitSymbol(TIntermSymbol * symbol)357     void visitSymbol(TIntermSymbol *symbol) override
358     {
359         // Cannot encounter the image symbol directly.  It can only be used with built-in functions,
360         // and therefore it's handled by visitAggregate.
361         ASSERT(!IsR32fImage(symbol->getType()));
362     }
363 
364   private:
365     TCompiler *mCompiler;
366 
367     // Map from r32f image to r32ui image
368     ImageMap mImageMap;
369 };
370 
371 }  // anonymous namespace
372 
RewriteR32fImages(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)373 bool RewriteR32fImages(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
374 {
375     RewriteR32fImagesTraverser traverser(compiler, symbolTable);
376     root->traverse(&traverser);
377     return traverser.updateTree(compiler, root);
378 }
379 }  // namespace sh
380