1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // MonomorphizeUnsupportedFunctionsInVulkanGLSL: Monomorphize functions that are called with
7 // parameters that are not compatible with Vulkan GLSL.
8 //
9
10 #include "compiler/translator/tree_ops/vulkan/MonomorphizeUnsupportedFunctionsInVulkanGLSL.h"
11
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18
19 namespace sh
20 {
21 namespace
22 {
23 struct Argument
24 {
25 size_t argumentIndex;
26 TIntermTyped *argument;
27 };
28
29 struct FunctionData
30 {
31 // Whether the original function is used. If this is false, the function can be removed because
32 // all callers have been modified.
33 bool isOriginalUsed;
34 // The original definition of the function, used to create the monomorphized version.
35 TIntermFunctionDefinition *originalDefinition;
36 // List of monomorphized versions of this function. They will be added next to the original
37 // version (or replace it).
38 TVector<TIntermFunctionDefinition *> monomorphizedDefinitions;
39 };
40
41 using FunctionMap = angle::HashMap<const TFunction *, FunctionData>;
42
43 // Traverse the function definitions and initialize the map. Allows visitAggregate to have access
44 // to TIntermFunctionDefinition even when the function is only forward declared at that point.
InitializeFunctionMap(TIntermBlock * root,FunctionMap * functionMapOut)45 void InitializeFunctionMap(TIntermBlock *root, FunctionMap *functionMapOut)
46 {
47 TIntermSequence &sequence = *root->getSequence();
48
49 for (TIntermNode *node : sequence)
50 {
51 TIntermFunctionDefinition *asFuncDef = node->getAsFunctionDefinition();
52 if (asFuncDef != nullptr)
53 {
54 const TFunction *function = asFuncDef->getFunction();
55 ASSERT(function && functionMapOut->find(function) == functionMapOut->end());
56 (*functionMapOut)[function] = FunctionData{false, asFuncDef, {}};
57 }
58 }
59 }
60
GetBaseUniform(TIntermTyped * node,bool * isSamplerInStructOut)61 const TVariable *GetBaseUniform(TIntermTyped *node, bool *isSamplerInStructOut)
62 {
63 *isSamplerInStructOut = false;
64
65 while (node->getAsBinaryNode())
66 {
67 TIntermBinary *asBinary = node->getAsBinaryNode();
68
69 TOperator op = asBinary->getOp();
70
71 // No opaque uniform can be inside an interface block.
72 if (op == EOpIndexDirectInterfaceBlock)
73 {
74 return nullptr;
75 }
76
77 if (op == EOpIndexDirectStruct)
78 {
79 *isSamplerInStructOut = true;
80 }
81
82 node = asBinary->getLeft();
83 }
84
85 // Only interested in uniform opaque types. If a function call within another function uses
86 // opaque uniforms in an unsupported way, it will be replaced in a follow up pass after the
87 // calling function is monomorphized.
88 if (node->getType().getQualifier() != EvqUniform)
89 {
90 return nullptr;
91 }
92
93 ASSERT(IsOpaqueType(node->getType().getBasicType()) ||
94 node->getType().isStructureContainingSamplers());
95
96 TIntermSymbol *asSymbol = node->getAsSymbolNode();
97 ASSERT(asSymbol);
98
99 return &asSymbol->variable();
100 }
101
ExtractSideEffects(TSymbolTable * symbolTable,TIntermTyped * node,TIntermSequence * replacementIndices)102 TIntermTyped *ExtractSideEffects(TSymbolTable *symbolTable,
103 TIntermTyped *node,
104 TIntermSequence *replacementIndices)
105 {
106 TIntermTyped *withoutSideEffects = node->deepCopy();
107
108 for (TIntermBinary *asBinary = withoutSideEffects->getAsBinaryNode(); asBinary;
109 asBinary = asBinary->getLeft()->getAsBinaryNode())
110 {
111 TOperator op = asBinary->getOp();
112 TIntermTyped *index = asBinary->getRight();
113
114 if (op == EOpIndexDirectStruct)
115 {
116 break;
117 }
118
119 // No side effects with constant expressions.
120 if (op == EOpIndexDirect)
121 {
122 ASSERT(index->getAsConstantUnion());
123 continue;
124 }
125
126 ASSERT(op == EOpIndexIndirect);
127
128 // If the index is a symbol, there's no side effect, so leave it as-is.
129 if (index->getAsSymbolNode())
130 {
131 continue;
132 }
133
134 // Otherwise create a temp variable initialized with the index and use that temp variable as
135 // the index.
136 TIntermDeclaration *tempDecl = nullptr;
137 TVariable *tempVar = DeclareTempVariable(symbolTable, index, EvqTemporary, &tempDecl);
138
139 replacementIndices->push_back(tempDecl);
140 asBinary->replaceChildNode(index, new TIntermSymbol(tempVar));
141 }
142
143 return withoutSideEffects;
144 }
145
CreateMonomorphizedFunctionCallArgs(const TIntermSequence & originalCallArguments,const TVector<Argument> & replacedArguments,TIntermSequence * substituteArgsOut)146 void CreateMonomorphizedFunctionCallArgs(const TIntermSequence &originalCallArguments,
147 const TVector<Argument> &replacedArguments,
148 TIntermSequence *substituteArgsOut)
149 {
150 size_t nextReplacedArg = 0;
151 for (size_t argIndex = 0; argIndex < originalCallArguments.size(); ++argIndex)
152 {
153 if (nextReplacedArg >= replacedArguments.size() ||
154 argIndex != replacedArguments[nextReplacedArg].argumentIndex)
155 {
156 // Not replaced, keep argument as is.
157 substituteArgsOut->push_back(originalCallArguments[argIndex]);
158 }
159 else
160 {
161 TIntermTyped *argument = replacedArguments[nextReplacedArg].argument;
162
163 // Iterate over indices of the argument and create a new arg for every non-const
164 // index. Note that the index itself may be an expression, and it may require further
165 // substitution in the next pass.
166 while (argument->getAsBinaryNode())
167 {
168 TIntermBinary *asBinary = argument->getAsBinaryNode();
169 if (asBinary->getOp() == EOpIndexIndirect)
170 {
171 TIntermTyped *index = asBinary->getRight();
172 substituteArgsOut->push_back(index->deepCopy());
173 }
174 argument = asBinary->getLeft();
175 }
176
177 ++nextReplacedArg;
178 }
179 }
180 }
181
MonomorphizeFunction(TSymbolTable * symbolTable,const TFunction * original,TVector<Argument> * replacedArguments,VariableReplacementMap * argumentMapOut)182 const TFunction *MonomorphizeFunction(TSymbolTable *symbolTable,
183 const TFunction *original,
184 TVector<Argument> *replacedArguments,
185 VariableReplacementMap *argumentMapOut)
186 {
187 TFunction *substituteFunction =
188 new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
189 &original->getReturnType(), original->isKnownToNotHaveSideEffects());
190
191 size_t nextReplacedArg = 0;
192 for (size_t paramIndex = 0; paramIndex < original->getParamCount(); ++paramIndex)
193 {
194 const TVariable *originalParam = original->getParam(paramIndex);
195
196 if (nextReplacedArg >= replacedArguments->size() ||
197 paramIndex != (*replacedArguments)[nextReplacedArg].argumentIndex)
198 {
199 TVariable *substituteArgument =
200 new TVariable(symbolTable, originalParam->name(), &originalParam->getType(),
201 originalParam->symbolType());
202 // Not replaced, add an identical parameter.
203 substituteFunction->addParameter(substituteArgument);
204 (*argumentMapOut)[originalParam] = new TIntermSymbol(substituteArgument);
205 }
206 else
207 {
208 TIntermTyped *substituteArgument = (*replacedArguments)[nextReplacedArg].argument;
209 (*argumentMapOut)[originalParam] = substituteArgument;
210
211 // Iterate over indices of the argument and create a new parameter for every non-const
212 // index (which may be an expression). Replace the symbol in the argument with a
213 // variable of the index type. This is later used to replace the parameter in the
214 // function body.
215 while (substituteArgument->getAsBinaryNode())
216 {
217 TIntermBinary *asBinary = substituteArgument->getAsBinaryNode();
218 if (asBinary->getOp() == EOpIndexIndirect)
219 {
220 TIntermTyped *index = asBinary->getRight();
221 TType *indexType = new TType(index->getType());
222 indexType->setQualifier(EvqIn);
223
224 TVariable *param = new TVariable(symbolTable, kEmptyImmutableString, indexType,
225 SymbolType::AngleInternal);
226 substituteFunction->addParameter(param);
227
228 // The argument now uses the function parameters as indices.
229 asBinary->replaceChildNode(asBinary->getRight(), new TIntermSymbol(param));
230 }
231 substituteArgument = asBinary->getLeft();
232 }
233
234 ++nextReplacedArg;
235 }
236 }
237
238 return substituteFunction;
239 }
240
241 class MonomorphizeTraverser final : public TIntermTraverser
242 {
243 public:
MonomorphizeTraverser(TCompiler * compiler,TSymbolTable * symbolTable,ShCompileOptions compileOptions,FunctionMap * functionMap)244 explicit MonomorphizeTraverser(TCompiler *compiler,
245 TSymbolTable *symbolTable,
246 ShCompileOptions compileOptions,
247 FunctionMap *functionMap)
248 : TIntermTraverser(true, false, false, symbolTable),
249 mCompiler(compiler),
250 mCompileOptions(compileOptions),
251 mFunctionMap(functionMap)
252 {}
253
visitAggregate(Visit visit,TIntermAggregate * node)254 bool visitAggregate(Visit visit, TIntermAggregate *node) override
255 {
256 if (node->getOp() != EOpCallFunctionInAST)
257 {
258 return true;
259 }
260
261 const TFunction *function = node->getFunction();
262 ASSERT(function && mFunctionMap->find(function) != mFunctionMap->end());
263
264 FunctionData &data = (*mFunctionMap)[function];
265
266 TIntermFunctionDefinition *monomorphized =
267 processFunctionCall(node, data.originalDefinition, &data.isOriginalUsed);
268 if (monomorphized)
269 {
270 data.monomorphizedDefinitions.push_back(monomorphized);
271 }
272
273 return true;
274 }
275
getAnyMonomorphized() const276 bool getAnyMonomorphized() const { return mAnyMonomorphized; }
277
278 private:
processFunctionCall(TIntermAggregate * functionCall,TIntermFunctionDefinition * originalDefinition,bool * isOriginalUsedOut)279 TIntermFunctionDefinition *processFunctionCall(TIntermAggregate *functionCall,
280 TIntermFunctionDefinition *originalDefinition,
281 bool *isOriginalUsedOut)
282 {
283 const TFunction *function = functionCall->getFunction();
284 const TIntermSequence &callArguments = *functionCall->getSequence();
285
286 TVector<Argument> replacedArguments;
287 TIntermSequence replacementIndices;
288
289 // Go through function call arguments, and see if any is used in an unsupported way.
290 for (size_t argIndex = 0; argIndex < callArguments.size(); ++argIndex)
291 {
292 TIntermTyped *callArgument = callArguments[argIndex]->getAsTyped();
293 const TVariable *funcArgument = function->getParam(argIndex);
294
295 // Only interested in opaque uniforms and structs that contain samplers.
296 const bool isOpaqueType = IsOpaqueType(funcArgument->getType().getBasicType());
297 const bool isStructContainingSamplers =
298 funcArgument->getType().isStructureContainingSamplers();
299 if (!isOpaqueType && !isStructContainingSamplers)
300 {
301 continue;
302 }
303
304 // If not uniform (the variable was itself a function parameter), don't process it in
305 // this pass, as we don't know which actual uniform it corresponds to.
306 bool isSamplerInStruct = false;
307 const TVariable *uniform = GetBaseUniform(callArgument, &isSamplerInStruct);
308 if (uniform == nullptr)
309 {
310 continue;
311 }
312
313 // Conditions for monomorphization:
314 //
315 // - If the parameter is a structure that contains samplers (so in RewriteStructSamplers
316 // we don't need to rewrite the functions to accept multiple parameters split from the
317 // struct), or
318 // - If the opaque uniform is a sampler in a struct (which can create an array-of-array
319 // situation), and the function expects an array of samplers, or
320 // - If the opaque uniform is an array of array of sampler or image, and it's partially
321 // subscripted (i.e. the function itself expects an array), or
322 // - The opaque uniform is an atomic counter
323 // - The opaque uniform is a samplerCube and ES2's cube sampling emulation is requested.
324 // - The opaque uniform is an image* with r32f format.
325 //
326 const TType &type = uniform->getType();
327 const bool isArrayOfArrayOfSamplerOrImage =
328 (type.isSampler() || type.isImage()) && type.isArrayOfArrays();
329 const bool isParameterArrayOfOpaqueType = funcArgument->getType().isArray();
330 const bool isAtomicCounter = type.isAtomicCounter();
331 const bool isSamplerCubeEmulation =
332 type.isSamplerCube() &&
333 (mCompileOptions & SH_EMULATE_SEAMFUL_CUBE_MAP_SAMPLING) != 0;
334 const bool isR32fImage =
335 type.isImage() && type.getLayoutQualifier().imageInternalFormat == EiifR32F;
336
337 if (!(isStructContainingSamplers ||
338 (isSamplerInStruct && isParameterArrayOfOpaqueType) ||
339 (isArrayOfArrayOfSamplerOrImage && isParameterArrayOfOpaqueType) ||
340 isAtomicCounter || isSamplerCubeEmulation || isR32fImage))
341 {
342 continue;
343 }
344
345 // Copy the argument and extract the side effects.
346 TIntermTyped *argument =
347 ExtractSideEffects(mSymbolTable, callArgument, &replacementIndices);
348
349 replacedArguments.push_back({argIndex, argument});
350 }
351
352 if (replacedArguments.empty())
353 {
354 *isOriginalUsedOut = true;
355 return nullptr;
356 }
357
358 mAnyMonomorphized = true;
359
360 insertStatementsInParentBlock(replacementIndices);
361
362 // Create the arguments for the substitute function call. Done before monomorphizing the
363 // function, which transforms the arguments to what needs to be replaced in the function
364 // body.
365 TIntermSequence newCallArgs;
366 CreateMonomorphizedFunctionCallArgs(callArguments, replacedArguments, &newCallArgs);
367
368 // Duplicate the function and substitute the replaced arguments with only the non-const
369 // indices. Additionally, substitute the non-const indices of arguments with the new
370 // function parameters.
371 VariableReplacementMap argumentMap;
372 const TFunction *monomorphized =
373 MonomorphizeFunction(mSymbolTable, function, &replacedArguments, &argumentMap);
374
375 // Replace this function call with a call to the new one.
376 queueReplacement(TIntermAggregate::CreateFunctionCall(*monomorphized, &newCallArgs),
377 OriginalNode::IS_DROPPED);
378
379 // Create a new function definition, with the body of the old function but with the replaced
380 // parameters substituted with the calling expressions.
381 TIntermFunctionPrototype *substitutePrototype = new TIntermFunctionPrototype(monomorphized);
382 TIntermBlock *substituteBlock = originalDefinition->getBody()->deepCopy();
383 GetDeclaratorReplacements(mSymbolTable, substituteBlock, &argumentMap);
384 bool valid = ReplaceVariables(mCompiler, substituteBlock, argumentMap);
385 ASSERT(valid);
386
387 return new TIntermFunctionDefinition(substitutePrototype, substituteBlock);
388 }
389
390 TCompiler *mCompiler;
391 ShCompileOptions mCompileOptions;
392 bool mAnyMonomorphized = false;
393
394 // Map of original to monomorphized functions.
395 FunctionMap *mFunctionMap;
396 };
397
398 class UpdateFunctionsDefinitionsTraverser final : public TIntermTraverser
399 {
400 public:
UpdateFunctionsDefinitionsTraverser(TSymbolTable * symbolTable,const FunctionMap & functionMap)401 explicit UpdateFunctionsDefinitionsTraverser(TSymbolTable *symbolTable,
402 const FunctionMap &functionMap)
403 : TIntermTraverser(true, false, false, symbolTable), mFunctionMap(functionMap)
404 {}
405
visitFunctionPrototype(TIntermFunctionPrototype * node)406 void visitFunctionPrototype(TIntermFunctionPrototype *node) override
407 {
408 const bool isInFunctionDefinition = getParentNode()->getAsFunctionDefinition() != nullptr;
409 if (isInFunctionDefinition)
410 {
411 return;
412 }
413
414 // Add to and possibly replace the function prototype with replacement prototypes.
415 const TFunction *function = node->getFunction();
416 ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
417
418 const FunctionData &data = mFunctionMap.at(function);
419
420 // If nothing to do, leave it be.
421 if (data.monomorphizedDefinitions.empty())
422 {
423 ASSERT(data.isOriginalUsed);
424 return;
425 }
426
427 // Replace the prototype with itself (if function is still used) as well as any
428 // monomorphized versions.
429 TIntermSequence replacement;
430 if (data.isOriginalUsed)
431 {
432 replacement.push_back(node);
433 }
434 for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
435 {
436 replacement.push_back(new TIntermFunctionPrototype(
437 monomorphizedDefinition->getFunctionPrototype()->getFunction()));
438 }
439 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
440 std::move(replacement));
441 }
442
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)443 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
444 {
445 // Add to and possibly replace the function definition with replacement definitions.
446 const TFunction *function = node->getFunction();
447 ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
448
449 const FunctionData &data = mFunctionMap.at(function);
450
451 // If nothing to do, leave it be.
452 if (data.monomorphizedDefinitions.empty())
453 {
454 ASSERT(data.isOriginalUsed || function->name() == "main");
455 return false;
456 }
457
458 // Replace the definition with itself (if function is still used) as well as any
459 // monomorphized versions.
460 TIntermSequence replacement;
461 if (data.isOriginalUsed)
462 {
463 replacement.push_back(node);
464 }
465 for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
466 {
467 replacement.push_back(monomorphizedDefinition);
468 }
469 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
470 std::move(replacement));
471
472 return false;
473 }
474
475 private:
476 const FunctionMap &mFunctionMap;
477 };
478
SortDeclarations(TIntermBlock * root)479 void SortDeclarations(TIntermBlock *root)
480 {
481 TIntermSequence *original = root->getSequence();
482
483 TIntermSequence replacement;
484 TIntermSequence functionDefs;
485
486 // Accumulate non-function-definition declarations in |replacement| and function definitions in
487 // |functionDefs|.
488 for (TIntermNode *node : *original)
489 {
490 if (node->getAsFunctionDefinition() || node->getAsFunctionPrototypeNode())
491 {
492 functionDefs.push_back(node);
493 }
494 else
495 {
496 replacement.push_back(node);
497 }
498 }
499
500 // Append function definitions to |replacement|.
501 replacement.insert(replacement.end(), functionDefs.begin(), functionDefs.end());
502
503 // Replace root's sequence with |replacement|.
504 root->replaceAllChildren(replacement);
505 }
506 } // anonymous namespace
507
MonomorphizeUnsupportedFunctionsInVulkanGLSL(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,ShCompileOptions compileOptions)508 bool MonomorphizeUnsupportedFunctionsInVulkanGLSL(TCompiler *compiler,
509 TIntermBlock *root,
510 TSymbolTable *symbolTable,
511 ShCompileOptions compileOptions)
512 {
513 // First, sort out the declarations such that all non-function declarations are placed before
514 // function definitions. This way when the function is replaced with one that references said
515 // declarations (i.e. uniforms), the uniform declaration is already present above it.
516 SortDeclarations(root);
517
518 while (true)
519 {
520 FunctionMap functionMap;
521 InitializeFunctionMap(root, &functionMap);
522
523 MonomorphizeTraverser monomorphizer(compiler, symbolTable, compileOptions, &functionMap);
524 root->traverse(&monomorphizer);
525
526 if (!monomorphizer.getAnyMonomorphized())
527 {
528 break;
529 }
530
531 if (!monomorphizer.updateTree(compiler, root))
532 {
533 return false;
534 }
535
536 UpdateFunctionsDefinitionsTraverser functionUpdater(symbolTable, functionMap);
537 root->traverse(&functionUpdater);
538
539 if (!functionUpdater.updateTree(compiler, root))
540 {
541 return false;
542 }
543 }
544
545 return true;
546 }
547 } // namespace sh
548