1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // MonomorphizeUnsupportedFunctionsInVulkanGLSL: Monomorphize functions that are called with
7 // parameters that are not compatible with Vulkan GLSL.
8 //
9 
10 #include "compiler/translator/tree_ops/vulkan/MonomorphizeUnsupportedFunctionsInVulkanGLSL.h"
11 
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18 
19 namespace sh
20 {
21 namespace
22 {
// A function-call argument that must be substituted directly into a monomorphized copy of the
// callee instead of being passed as a parameter.
struct Argument
{
    // Position of the argument in the call, which is also the index of the corresponding
    // parameter of the original function.
    size_t argumentIndex;
    // Deep copy of the call argument with side-effecting indices hoisted out (see
    // ExtractSideEffects).  MonomorphizeFunction later rewrites the non-constant indices of this
    // expression in place so that they refer to the new function's parameters.
    TIntermTyped *argument;
};
28 
// Per-function bookkeeping for one monomorphization pass.
struct FunctionData
{
    // Whether the original function is still called in its original form (set when a call to it
    // needs no monomorphization).  If this is false, the function can be removed because all
    // callers have been modified.
    bool isOriginalUsed;
    // The original definition of the function, used to create the monomorphized version.
    TIntermFunctionDefinition *originalDefinition;
    // List of monomorphized versions of this function, one per rewritten call.  They will be added
    // next to the original version (or replace it).
    TVector<TIntermFunctionDefinition *> monomorphizedDefinitions;
};

// Maps each function defined at global scope to its bookkeeping data.
using FunctionMap = angle::HashMap<const TFunction *, FunctionData>;
42 
43 // Traverse the function definitions and initialize the map.  Allows visitAggregate to have access
44 // to TIntermFunctionDefinition even when the function is only forward declared at that point.
InitializeFunctionMap(TIntermBlock * root,FunctionMap * functionMapOut)45 void InitializeFunctionMap(TIntermBlock *root, FunctionMap *functionMapOut)
46 {
47     TIntermSequence &sequence = *root->getSequence();
48 
49     for (TIntermNode *node : sequence)
50     {
51         TIntermFunctionDefinition *asFuncDef = node->getAsFunctionDefinition();
52         if (asFuncDef != nullptr)
53         {
54             const TFunction *function = asFuncDef->getFunction();
55             ASSERT(function && functionMapOut->find(function) == functionMapOut->end());
56             (*functionMapOut)[function] = FunctionData{false, asFuncDef, {}};
57         }
58     }
59 }
60 
GetBaseUniform(TIntermTyped * node,bool * isSamplerInStructOut)61 const TVariable *GetBaseUniform(TIntermTyped *node, bool *isSamplerInStructOut)
62 {
63     *isSamplerInStructOut = false;
64 
65     while (node->getAsBinaryNode())
66     {
67         TIntermBinary *asBinary = node->getAsBinaryNode();
68 
69         TOperator op = asBinary->getOp();
70 
71         // No opaque uniform can be inside an interface block.
72         if (op == EOpIndexDirectInterfaceBlock)
73         {
74             return nullptr;
75         }
76 
77         if (op == EOpIndexDirectStruct)
78         {
79             *isSamplerInStructOut = true;
80         }
81 
82         node = asBinary->getLeft();
83     }
84 
85     // Only interested in uniform opaque types.  If a function call within another function uses
86     // opaque uniforms in an unsupported way, it will be replaced in a follow up pass after the
87     // calling function is monomorphized.
88     if (node->getType().getQualifier() != EvqUniform)
89     {
90         return nullptr;
91     }
92 
93     ASSERT(IsOpaqueType(node->getType().getBasicType()) ||
94            node->getType().isStructureContainingSamplers());
95 
96     TIntermSymbol *asSymbol = node->getAsSymbolNode();
97     ASSERT(asSymbol);
98 
99     return &asSymbol->variable();
100 }
101 
ExtractSideEffects(TSymbolTable * symbolTable,TIntermTyped * node,TIntermSequence * replacementIndices)102 TIntermTyped *ExtractSideEffects(TSymbolTable *symbolTable,
103                                  TIntermTyped *node,
104                                  TIntermSequence *replacementIndices)
105 {
106     TIntermTyped *withoutSideEffects = node->deepCopy();
107 
108     for (TIntermBinary *asBinary = withoutSideEffects->getAsBinaryNode(); asBinary;
109          asBinary                = asBinary->getLeft()->getAsBinaryNode())
110     {
111         TOperator op        = asBinary->getOp();
112         TIntermTyped *index = asBinary->getRight();
113 
114         if (op == EOpIndexDirectStruct)
115         {
116             break;
117         }
118 
119         // No side effects with constant expressions.
120         if (op == EOpIndexDirect)
121         {
122             ASSERT(index->getAsConstantUnion());
123             continue;
124         }
125 
126         ASSERT(op == EOpIndexIndirect);
127 
128         // If the index is a symbol, there's no side effect, so leave it as-is.
129         if (index->getAsSymbolNode())
130         {
131             continue;
132         }
133 
134         // Otherwise create a temp variable initialized with the index and use that temp variable as
135         // the index.
136         TIntermDeclaration *tempDecl = nullptr;
137         TVariable *tempVar = DeclareTempVariable(symbolTable, index, EvqTemporary, &tempDecl);
138 
139         replacementIndices->push_back(tempDecl);
140         asBinary->replaceChildNode(index, new TIntermSymbol(tempVar));
141     }
142 
143     return withoutSideEffects;
144 }
145 
CreateMonomorphizedFunctionCallArgs(const TIntermSequence & originalCallArguments,const TVector<Argument> & replacedArguments,TIntermSequence * substituteArgsOut)146 void CreateMonomorphizedFunctionCallArgs(const TIntermSequence &originalCallArguments,
147                                          const TVector<Argument> &replacedArguments,
148                                          TIntermSequence *substituteArgsOut)
149 {
150     size_t nextReplacedArg = 0;
151     for (size_t argIndex = 0; argIndex < originalCallArguments.size(); ++argIndex)
152     {
153         if (nextReplacedArg >= replacedArguments.size() ||
154             argIndex != replacedArguments[nextReplacedArg].argumentIndex)
155         {
156             // Not replaced, keep argument as is.
157             substituteArgsOut->push_back(originalCallArguments[argIndex]);
158         }
159         else
160         {
161             TIntermTyped *argument = replacedArguments[nextReplacedArg].argument;
162 
163             // Iterate over indices of the argument and create a new arg for every non-const
164             // index.  Note that the index itself may be an expression, and it may require further
165             // substitution in the next pass.
166             while (argument->getAsBinaryNode())
167             {
168                 TIntermBinary *asBinary = argument->getAsBinaryNode();
169                 if (asBinary->getOp() == EOpIndexIndirect)
170                 {
171                     TIntermTyped *index = asBinary->getRight();
172                     substituteArgsOut->push_back(index->deepCopy());
173                 }
174                 argument = asBinary->getLeft();
175             }
176 
177             ++nextReplacedArg;
178         }
179     }
180 }
181 
MonomorphizeFunction(TSymbolTable * symbolTable,const TFunction * original,TVector<Argument> * replacedArguments,VariableReplacementMap * argumentMapOut)182 const TFunction *MonomorphizeFunction(TSymbolTable *symbolTable,
183                                       const TFunction *original,
184                                       TVector<Argument> *replacedArguments,
185                                       VariableReplacementMap *argumentMapOut)
186 {
187     TFunction *substituteFunction =
188         new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
189                       &original->getReturnType(), original->isKnownToNotHaveSideEffects());
190 
191     size_t nextReplacedArg = 0;
192     for (size_t paramIndex = 0; paramIndex < original->getParamCount(); ++paramIndex)
193     {
194         const TVariable *originalParam = original->getParam(paramIndex);
195 
196         if (nextReplacedArg >= replacedArguments->size() ||
197             paramIndex != (*replacedArguments)[nextReplacedArg].argumentIndex)
198         {
199             TVariable *substituteArgument =
200                 new TVariable(symbolTable, originalParam->name(), &originalParam->getType(),
201                               originalParam->symbolType());
202             // Not replaced, add an identical parameter.
203             substituteFunction->addParameter(substituteArgument);
204             (*argumentMapOut)[originalParam] = new TIntermSymbol(substituteArgument);
205         }
206         else
207         {
208             TIntermTyped *substituteArgument = (*replacedArguments)[nextReplacedArg].argument;
209             (*argumentMapOut)[originalParam] = substituteArgument;
210 
211             // Iterate over indices of the argument and create a new parameter for every non-const
212             // index (which may be an expression).  Replace the symbol in the argument with a
213             // variable of the index type.  This is later used to replace the parameter in the
214             // function body.
215             while (substituteArgument->getAsBinaryNode())
216             {
217                 TIntermBinary *asBinary = substituteArgument->getAsBinaryNode();
218                 if (asBinary->getOp() == EOpIndexIndirect)
219                 {
220                     TIntermTyped *index = asBinary->getRight();
221                     TType *indexType    = new TType(index->getType());
222                     indexType->setQualifier(EvqIn);
223 
224                     TVariable *param = new TVariable(symbolTable, kEmptyImmutableString, indexType,
225                                                      SymbolType::AngleInternal);
226                     substituteFunction->addParameter(param);
227 
228                     // The argument now uses the function parameters as indices.
229                     asBinary->replaceChildNode(asBinary->getRight(), new TIntermSymbol(param));
230                 }
231                 substituteArgument = asBinary->getLeft();
232             }
233 
234             ++nextReplacedArg;
235         }
236     }
237 
238     return substituteFunction;
239 }
240 
241 class MonomorphizeTraverser final : public TIntermTraverser
242 {
243   public:
MonomorphizeTraverser(TCompiler * compiler,TSymbolTable * symbolTable,ShCompileOptions compileOptions,FunctionMap * functionMap)244     explicit MonomorphizeTraverser(TCompiler *compiler,
245                                    TSymbolTable *symbolTable,
246                                    ShCompileOptions compileOptions,
247                                    FunctionMap *functionMap)
248         : TIntermTraverser(true, false, false, symbolTable),
249           mCompiler(compiler),
250           mCompileOptions(compileOptions),
251           mFunctionMap(functionMap)
252     {}
253 
visitAggregate(Visit visit,TIntermAggregate * node)254     bool visitAggregate(Visit visit, TIntermAggregate *node) override
255     {
256         if (node->getOp() != EOpCallFunctionInAST)
257         {
258             return true;
259         }
260 
261         const TFunction *function = node->getFunction();
262         ASSERT(function && mFunctionMap->find(function) != mFunctionMap->end());
263 
264         FunctionData &data = (*mFunctionMap)[function];
265 
266         TIntermFunctionDefinition *monomorphized =
267             processFunctionCall(node, data.originalDefinition, &data.isOriginalUsed);
268         if (monomorphized)
269         {
270             data.monomorphizedDefinitions.push_back(monomorphized);
271         }
272 
273         return true;
274     }
275 
getAnyMonomorphized() const276     bool getAnyMonomorphized() const { return mAnyMonomorphized; }
277 
278   private:
processFunctionCall(TIntermAggregate * functionCall,TIntermFunctionDefinition * originalDefinition,bool * isOriginalUsedOut)279     TIntermFunctionDefinition *processFunctionCall(TIntermAggregate *functionCall,
280                                                    TIntermFunctionDefinition *originalDefinition,
281                                                    bool *isOriginalUsedOut)
282     {
283         const TFunction *function            = functionCall->getFunction();
284         const TIntermSequence &callArguments = *functionCall->getSequence();
285 
286         TVector<Argument> replacedArguments;
287         TIntermSequence replacementIndices;
288 
289         // Go through function call arguments, and see if any is used in an unsupported way.
290         for (size_t argIndex = 0; argIndex < callArguments.size(); ++argIndex)
291         {
292             TIntermTyped *callArgument    = callArguments[argIndex]->getAsTyped();
293             const TVariable *funcArgument = function->getParam(argIndex);
294 
295             // Only interested in opaque uniforms and structs that contain samplers.
296             const bool isOpaqueType = IsOpaqueType(funcArgument->getType().getBasicType());
297             const bool isStructContainingSamplers =
298                 funcArgument->getType().isStructureContainingSamplers();
299             if (!isOpaqueType && !isStructContainingSamplers)
300             {
301                 continue;
302             }
303 
304             // If not uniform (the variable was itself a function parameter), don't process it in
305             // this pass, as we don't know which actual uniform it corresponds to.
306             bool isSamplerInStruct   = false;
307             const TVariable *uniform = GetBaseUniform(callArgument, &isSamplerInStruct);
308             if (uniform == nullptr)
309             {
310                 continue;
311             }
312 
313             // Conditions for monomorphization:
314             //
315             // - If the parameter is a structure that contains samplers (so in RewriteStructSamplers
316             //   we don't need to rewrite the functions to accept multiple parameters split from the
317             //   struct), or
318             // - If the opaque uniform is a sampler in a struct (which can create an array-of-array
319             //   situation), and the function expects an array of samplers, or
320             // - If the opaque uniform is an array of array of sampler or image, and it's partially
321             //   subscripted (i.e. the function itself expects an array), or
322             // - The opaque uniform is an atomic counter
323             // - The opaque uniform is a samplerCube and ES2's cube sampling emulation is requested.
324             // - The opaque uniform is an image* with r32f format.
325             //
326             const TType &type = uniform->getType();
327             const bool isArrayOfArrayOfSamplerOrImage =
328                 (type.isSampler() || type.isImage()) && type.isArrayOfArrays();
329             const bool isParameterArrayOfOpaqueType = funcArgument->getType().isArray();
330             const bool isAtomicCounter              = type.isAtomicCounter();
331             const bool isSamplerCubeEmulation =
332                 type.isSamplerCube() &&
333                 (mCompileOptions & SH_EMULATE_SEAMFUL_CUBE_MAP_SAMPLING) != 0;
334             const bool isR32fImage =
335                 type.isImage() && type.getLayoutQualifier().imageInternalFormat == EiifR32F;
336 
337             if (!(isStructContainingSamplers ||
338                   (isSamplerInStruct && isParameterArrayOfOpaqueType) ||
339                   (isArrayOfArrayOfSamplerOrImage && isParameterArrayOfOpaqueType) ||
340                   isAtomicCounter || isSamplerCubeEmulation || isR32fImage))
341             {
342                 continue;
343             }
344 
345             // Copy the argument and extract the side effects.
346             TIntermTyped *argument =
347                 ExtractSideEffects(mSymbolTable, callArgument, &replacementIndices);
348 
349             replacedArguments.push_back({argIndex, argument});
350         }
351 
352         if (replacedArguments.empty())
353         {
354             *isOriginalUsedOut = true;
355             return nullptr;
356         }
357 
358         mAnyMonomorphized = true;
359 
360         insertStatementsInParentBlock(replacementIndices);
361 
362         // Create the arguments for the substitute function call.  Done before monomorphizing the
363         // function, which transforms the arguments to what needs to be replaced in the function
364         // body.
365         TIntermSequence newCallArgs;
366         CreateMonomorphizedFunctionCallArgs(callArguments, replacedArguments, &newCallArgs);
367 
368         // Duplicate the function and substitute the replaced arguments with only the non-const
369         // indices.  Additionally, substitute the non-const indices of arguments with the new
370         // function parameters.
371         VariableReplacementMap argumentMap;
372         const TFunction *monomorphized =
373             MonomorphizeFunction(mSymbolTable, function, &replacedArguments, &argumentMap);
374 
375         // Replace this function call with a call to the new one.
376         queueReplacement(TIntermAggregate::CreateFunctionCall(*monomorphized, &newCallArgs),
377                          OriginalNode::IS_DROPPED);
378 
379         // Create a new function definition, with the body of the old function but with the replaced
380         // parameters substituted with the calling expressions.
381         TIntermFunctionPrototype *substitutePrototype = new TIntermFunctionPrototype(monomorphized);
382         TIntermBlock *substituteBlock                 = originalDefinition->getBody()->deepCopy();
383         GetDeclaratorReplacements(mSymbolTable, substituteBlock, &argumentMap);
384         bool valid = ReplaceVariables(mCompiler, substituteBlock, argumentMap);
385         ASSERT(valid);
386 
387         return new TIntermFunctionDefinition(substitutePrototype, substituteBlock);
388     }
389 
390     TCompiler *mCompiler;
391     ShCompileOptions mCompileOptions;
392     bool mAnyMonomorphized = false;
393 
394     // Map of original to monomorphized functions.
395     FunctionMap *mFunctionMap;
396 };
397 
398 class UpdateFunctionsDefinitionsTraverser final : public TIntermTraverser
399 {
400   public:
UpdateFunctionsDefinitionsTraverser(TSymbolTable * symbolTable,const FunctionMap & functionMap)401     explicit UpdateFunctionsDefinitionsTraverser(TSymbolTable *symbolTable,
402                                                  const FunctionMap &functionMap)
403         : TIntermTraverser(true, false, false, symbolTable), mFunctionMap(functionMap)
404     {}
405 
visitFunctionPrototype(TIntermFunctionPrototype * node)406     void visitFunctionPrototype(TIntermFunctionPrototype *node) override
407     {
408         const bool isInFunctionDefinition = getParentNode()->getAsFunctionDefinition() != nullptr;
409         if (isInFunctionDefinition)
410         {
411             return;
412         }
413 
414         // Add to and possibly replace the function prototype with replacement prototypes.
415         const TFunction *function = node->getFunction();
416         ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
417 
418         const FunctionData &data = mFunctionMap.at(function);
419 
420         // If nothing to do, leave it be.
421         if (data.monomorphizedDefinitions.empty())
422         {
423             ASSERT(data.isOriginalUsed);
424             return;
425         }
426 
427         // Replace the prototype with itself (if function is still used) as well as any
428         // monomorphized versions.
429         TIntermSequence replacement;
430         if (data.isOriginalUsed)
431         {
432             replacement.push_back(node);
433         }
434         for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
435         {
436             replacement.push_back(new TIntermFunctionPrototype(
437                 monomorphizedDefinition->getFunctionPrototype()->getFunction()));
438         }
439         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
440                                         std::move(replacement));
441     }
442 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)443     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
444     {
445         // Add to and possibly replace the function definition with replacement definitions.
446         const TFunction *function = node->getFunction();
447         ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
448 
449         const FunctionData &data = mFunctionMap.at(function);
450 
451         // If nothing to do, leave it be.
452         if (data.monomorphizedDefinitions.empty())
453         {
454             ASSERT(data.isOriginalUsed || function->name() == "main");
455             return false;
456         }
457 
458         // Replace the definition with itself (if function is still used) as well as any
459         // monomorphized versions.
460         TIntermSequence replacement;
461         if (data.isOriginalUsed)
462         {
463             replacement.push_back(node);
464         }
465         for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
466         {
467             replacement.push_back(monomorphizedDefinition);
468         }
469         mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
470                                         std::move(replacement));
471 
472         return false;
473     }
474 
475   private:
476     const FunctionMap &mFunctionMap;
477 };
478 
SortDeclarations(TIntermBlock * root)479 void SortDeclarations(TIntermBlock *root)
480 {
481     TIntermSequence *original = root->getSequence();
482 
483     TIntermSequence replacement;
484     TIntermSequence functionDefs;
485 
486     // Accumulate non-function-definition declarations in |replacement| and function definitions in
487     // |functionDefs|.
488     for (TIntermNode *node : *original)
489     {
490         if (node->getAsFunctionDefinition() || node->getAsFunctionPrototypeNode())
491         {
492             functionDefs.push_back(node);
493         }
494         else
495         {
496             replacement.push_back(node);
497         }
498     }
499 
500     // Append function definitions to |replacement|.
501     replacement.insert(replacement.end(), functionDefs.begin(), functionDefs.end());
502 
503     // Replace root's sequence with |replacement|.
504     root->replaceAllChildren(replacement);
505 }
506 }  // anonymous namespace
507 
MonomorphizeUnsupportedFunctionsInVulkanGLSL(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,ShCompileOptions compileOptions)508 bool MonomorphizeUnsupportedFunctionsInVulkanGLSL(TCompiler *compiler,
509                                                   TIntermBlock *root,
510                                                   TSymbolTable *symbolTable,
511                                                   ShCompileOptions compileOptions)
512 {
513     // First, sort out the declarations such that all non-function declarations are placed before
514     // function definitions.  This way when the function is replaced with one that references said
515     // declarations (i.e. uniforms), the uniform declaration is already present above it.
516     SortDeclarations(root);
517 
518     while (true)
519     {
520         FunctionMap functionMap;
521         InitializeFunctionMap(root, &functionMap);
522 
523         MonomorphizeTraverser monomorphizer(compiler, symbolTable, compileOptions, &functionMap);
524         root->traverse(&monomorphizer);
525 
526         if (!monomorphizer.getAnyMonomorphized())
527         {
528             break;
529         }
530 
531         if (!monomorphizer.updateTree(compiler, root))
532         {
533             return false;
534         }
535 
536         UpdateFunctionsDefinitionsTraverser functionUpdater(symbolTable, functionMap);
537         root->traverse(&functionUpdater);
538 
539         if (!functionUpdater.updateTree(compiler, root))
540         {
541             return false;
542         }
543     }
544 
545     return true;
546 }
547 }  // namespace sh
548