1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // DeclarePerVertexBlocks: Declare gl_PerVertex blocks if not already.
7 //
8 
9 #include "compiler/translator/tree_ops/vulkan/DeclarePerVertexBlocks.h"
10 
11 #include "compiler/translator/Compiler.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/SymbolTable.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 #include "compiler/translator/tree_util/ReplaceVariable.h"
18 
19 namespace sh
20 {
21 namespace
22 {
23 using PerVertexMemberFlags = std::array<bool, 4>;
24 
GetPerVertexFieldIndex(const TQualifier qualifier,const ImmutableString & name)25 int GetPerVertexFieldIndex(const TQualifier qualifier, const ImmutableString &name)
26 {
27     switch (qualifier)
28     {
29         case EvqPosition:
30             ASSERT(name == "gl_Position");
31             return 0;
32         case EvqPointSize:
33             ASSERT(name == "gl_PointSize");
34             return 1;
35         case EvqClipDistance:
36             ASSERT(name == "gl_ClipDistance");
37             return 2;
38         case EvqCullDistance:
39             ASSERT(name == "gl_CullDistance");
40             return 3;
41         default:
42             return -1;
43     }
44 }
45 
46 // Traverser that:
47 //
48 // 1. Declares the input and output gl_PerVertex types and variables if not already (based on shader
49 //    type).
50 // 2. Turns built-in references into indexes into these variables.
51 class DeclarePerVertexBlocksTraverser : public TIntermTraverser
52 {
53   public:
DeclarePerVertexBlocksTraverser(TCompiler * compiler,TSymbolTable * symbolTable,const PerVertexMemberFlags & invariantFlags,const PerVertexMemberFlags & preciseFlags)54     DeclarePerVertexBlocksTraverser(TCompiler *compiler,
55                                     TSymbolTable *symbolTable,
56                                     const PerVertexMemberFlags &invariantFlags,
57                                     const PerVertexMemberFlags &preciseFlags)
58         : TIntermTraverser(true, false, false, symbolTable),
59           mShaderType(compiler->getShaderType()),
60           mResources(compiler->getResources()),
61           mPerVertexInVar(nullptr),
62           mPerVertexOutVar(nullptr),
63           mPerVertexInVarRedeclared(false),
64           mPerVertexOutVarRedeclared(false),
65           mPerVertexOutInvariantFlags(invariantFlags),
66           mPerVertexOutPreciseFlags(preciseFlags)
67     {}
68 
visitSymbol(TIntermSymbol * symbol)69     void visitSymbol(TIntermSymbol *symbol) override
70     {
71         const TVariable *variable = &symbol->variable();
72         const TType *type         = &variable->getType();
73 
74         // Replace gl_out if necessary.
75         if (mShaderType == GL_TESS_CONTROL_SHADER && type->getQualifier() == EvqPerVertexOut)
76         {
77             ASSERT(variable->name() == "gl_out");
78 
79             // Declare gl_out if not already.
80             if (mPerVertexOutVar == nullptr)
81             {
82                 // Record invariant and precise qualifiers used on the fields so they would be
83                 // applied to the replacement gl_out.
84                 for (const TField *field : type->getInterfaceBlock()->fields())
85                 {
86                     const TType &fieldType = *field->type();
87                     const int fieldIndex =
88                         GetPerVertexFieldIndex(fieldType.getQualifier(), field->name());
89                     ASSERT(fieldIndex >= 0);
90 
91                     if (fieldType.isInvariant())
92                     {
93                         mPerVertexOutInvariantFlags[fieldIndex] = true;
94                     }
95                     if (fieldType.isPrecise())
96                     {
97                         mPerVertexOutPreciseFlags[fieldIndex] = true;
98                     }
99                 }
100 
101                 declareDefaultGlOut();
102             }
103 
104             if (mPerVertexOutVarRedeclared)
105             {
106                 queueReplacement(new TIntermSymbol(mPerVertexOutVar), OriginalNode::IS_DROPPED);
107             }
108 
109             return;
110         }
111 
112         // Replace gl_in if necessary.
113         if ((mShaderType == GL_TESS_CONTROL_SHADER || mShaderType == GL_TESS_EVALUATION_SHADER ||
114              mShaderType == GL_GEOMETRY_SHADER) &&
115             type->getQualifier() == EvqPerVertexIn)
116         {
117             ASSERT(variable->name() == "gl_in");
118 
119             // Declare gl_in if not already.
120             if (mPerVertexInVar == nullptr)
121             {
122                 declareDefaultGlIn();
123             }
124 
125             if (mPerVertexInVarRedeclared)
126             {
127                 queueReplacement(new TIntermSymbol(mPerVertexInVar), OriginalNode::IS_DROPPED);
128             }
129 
130             return;
131         }
132 
133         // Turn gl_Position, gl_PointSize, gl_ClipDistance and gl_CullDistance into references to
134         // the output gl_PerVertex.  Note that the default gl_PerVertex is declared as follows:
135         //
136         //     out gl_PerVertex
137         //     {
138         //         vec4 gl_Position;
139         //         float gl_PointSize;
140         //         float gl_ClipDistance[];
141         //         float gl_CullDistance[];
142         //     };
143         //
144 
145         if (variable->symbolType() != SymbolType::BuiltIn)
146         {
147             ASSERT(variable->name() != "gl_Position" && variable->name() != "gl_PointSize" &&
148                    variable->name() != "gl_ClipDistance" && variable->name() != "gl_CullDistance");
149 
150             return;
151         }
152 
153         // If this built-in was already visited, reuse the variable defined for it.
154         auto replacement = mVariableMap.find(variable);
155         if (replacement != mVariableMap.end())
156         {
157             queueReplacement(replacement->second->deepCopy(), OriginalNode::IS_DROPPED);
158             return;
159         }
160 
161         const int fieldIndex = GetPerVertexFieldIndex(type->getQualifier(), variable->name());
162 
163         // Not the built-in we are looking for.
164         if (fieldIndex < 0)
165         {
166             return;
167         }
168 
169         // Declare the output gl_PerVertex if not already.
170         if (mPerVertexOutVar == nullptr)
171         {
172             declareDefaultGlOut();
173         }
174 
175         TType *newType = new TType(*type);
176         newType->setInterfaceBlockField(mPerVertexOutVar->getType().getInterfaceBlock(),
177                                         fieldIndex);
178 
179         TVariable *newVariable = new TVariable(mSymbolTable, variable->name(), newType,
180                                                variable->symbolType(), variable->extensions());
181 
182         TIntermSymbol *newSymbol = new TIntermSymbol(newVariable);
183         mVariableMap[variable]   = newSymbol;
184 
185         queueReplacement(newSymbol, OriginalNode::IS_DROPPED);
186     }
187 
getRedeclaredPerVertexOutVar()188     const TVariable *getRedeclaredPerVertexOutVar()
189     {
190         return mPerVertexOutVarRedeclared ? mPerVertexOutVar : nullptr;
191     }
192 
getRedeclaredPerVertexInVar()193     const TVariable *getRedeclaredPerVertexInVar()
194     {
195         return mPerVertexInVarRedeclared ? mPerVertexInVar : nullptr;
196     }
197 
198   private:
declarePerVertex(TQualifier qualifier,uint32_t arraySize,ImmutableString & variableName)199     const TVariable *declarePerVertex(TQualifier qualifier,
200                                       uint32_t arraySize,
201                                       ImmutableString &variableName)
202     {
203         TFieldList *fields = new TFieldList;
204 
205         const TType *vec4Type  = StaticType::GetBasic<EbtFloat, 4>();
206         const TType *floatType = StaticType::GetBasic<EbtFloat, 1>();
207 
208         TType *positionType     = new TType(*vec4Type);
209         TType *pointSizeType    = new TType(*floatType);
210         TType *clipDistanceType = new TType(*floatType);
211         TType *cullDistanceType = new TType(*floatType);
212 
213         positionType->setQualifier(EvqPosition);
214         pointSizeType->setQualifier(EvqPointSize);
215         clipDistanceType->setQualifier(EvqClipDistance);
216         cullDistanceType->setQualifier(EvqCullDistance);
217 
218         clipDistanceType->makeArray(mResources.MaxClipDistances);
219         cullDistanceType->makeArray(mResources.MaxCullDistances);
220 
221         if (qualifier == EvqPerVertexOut)
222         {
223             positionType->setInvariant(mPerVertexOutInvariantFlags[0]);
224             pointSizeType->setInvariant(mPerVertexOutInvariantFlags[1]);
225             clipDistanceType->setInvariant(mPerVertexOutInvariantFlags[2]);
226             cullDistanceType->setInvariant(mPerVertexOutInvariantFlags[3]);
227 
228             positionType->setPrecise(mPerVertexOutPreciseFlags[0]);
229             pointSizeType->setPrecise(mPerVertexOutPreciseFlags[1]);
230             clipDistanceType->setPrecise(mPerVertexOutPreciseFlags[2]);
231             cullDistanceType->setPrecise(mPerVertexOutPreciseFlags[3]);
232         }
233 
234         fields->push_back(new TField(positionType, ImmutableString("gl_Position"), TSourceLoc(),
235                                      SymbolType::AngleInternal));
236         fields->push_back(new TField(pointSizeType, ImmutableString("gl_PointSize"), TSourceLoc(),
237                                      SymbolType::AngleInternal));
238         fields->push_back(new TField(clipDistanceType, ImmutableString("gl_ClipDistance"),
239                                      TSourceLoc(), SymbolType::AngleInternal));
240         fields->push_back(new TField(cullDistanceType, ImmutableString("gl_CullDistance"),
241                                      TSourceLoc(), SymbolType::AngleInternal));
242 
243         TInterfaceBlock *interfaceBlock =
244             new TInterfaceBlock(mSymbolTable, ImmutableString("gl_PerVertex"), fields,
245                                 TLayoutQualifier::Create(), SymbolType::AngleInternal);
246 
247         TType *interfaceBlockType =
248             new TType(interfaceBlock, qualifier, TLayoutQualifier::Create());
249         if (arraySize > 0)
250         {
251             interfaceBlockType->makeArray(arraySize);
252         }
253 
254         TVariable *interfaceBlockVar =
255             new TVariable(mSymbolTable, variableName, interfaceBlockType,
256                           variableName.empty() ? SymbolType::Empty : SymbolType::AngleInternal);
257 
258         return interfaceBlockVar;
259     }
260 
declareDefaultGlOut()261     void declareDefaultGlOut()
262     {
263         ASSERT(!mPerVertexOutVarRedeclared);
264 
265         // For tessellation control shaders, gl_out is an array of MaxPatchVertices
266         // For other shaders, there's no explicit name or array size
267 
268         ImmutableString varName("");
269         uint32_t arraySize = 0;
270         if (mShaderType == GL_TESS_CONTROL_SHADER)
271         {
272             varName   = ImmutableString("gl_out");
273             arraySize = mResources.MaxPatchVertices;
274         }
275 
276         mPerVertexOutVar           = declarePerVertex(EvqPerVertexOut, arraySize, varName);
277         mPerVertexOutVarRedeclared = true;
278     }
279 
declareDefaultGlIn()280     void declareDefaultGlIn()
281     {
282         ASSERT(!mPerVertexInVarRedeclared);
283 
284         // For tessellation shaders, gl_in is an array of MaxPatchVertices.
285         // For geometry shaders, gl_in is sized based on the primitive type.
286 
287         ImmutableString varName("gl_in");
288         uint32_t arraySize = mResources.MaxPatchVertices;
289         if (mShaderType == GL_GEOMETRY_SHADER)
290         {
291             arraySize =
292                 mSymbolTable->getGlInVariableWithArraySize()->getType().getOutermostArraySize();
293         }
294 
295         mPerVertexInVar           = declarePerVertex(EvqPerVertexIn, arraySize, varName);
296         mPerVertexInVarRedeclared = true;
297     }
298 
299     GLenum mShaderType;
300     const ShBuiltInResources &mResources;
301 
302     const TVariable *mPerVertexInVar;
303     const TVariable *mPerVertexOutVar;
304 
305     bool mPerVertexInVarRedeclared;
306     bool mPerVertexOutVarRedeclared;
307 
308     // A map of already replaced built-in variables.
309     VariableReplacementMap mVariableMap;
310 
311     // Whether each field is invariant or precise.
312     PerVertexMemberFlags mPerVertexOutInvariantFlags;
313     PerVertexMemberFlags mPerVertexOutPreciseFlags;
314 };
315 
AddPerVertexDecl(TIntermBlock * root,const TVariable * variable)316 void AddPerVertexDecl(TIntermBlock *root, const TVariable *variable)
317 {
318     if (variable == nullptr)
319     {
320         return;
321     }
322 
323     TIntermDeclaration *decl = new TIntermDeclaration;
324     TIntermSymbol *symbol    = new TIntermSymbol(variable);
325     decl->appendDeclarator(symbol);
326 
327     // Insert the declaration before the first function.
328     size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
329     root->insertChildNodes(firstFunctionIndex, {decl});
330 }
331 }  // anonymous namespace
332 
DeclarePerVertexBlocks(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)333 bool DeclarePerVertexBlocks(TCompiler *compiler, TIntermBlock *root, TSymbolTable *symbolTable)
334 {
335     if (compiler->getShaderType() == GL_COMPUTE_SHADER ||
336         compiler->getShaderType() == GL_FRAGMENT_SHADER)
337     {
338         return true;
339     }
340 
341     // First, visit all global qualifier declarations and find which built-ins are invariant or
342     // precise.
343     PerVertexMemberFlags invariantFlags = {};
344     PerVertexMemberFlags preciseFlags   = {};
345 
346     TIntermSequence withoutPerVertexGlobalQualifierDeclarations;
347 
348     for (TIntermNode *node : *root->getSequence())
349     {
350         TIntermGlobalQualifierDeclaration *asGlobalQualifierDecl =
351             node->getAsGlobalQualifierDeclarationNode();
352         if (asGlobalQualifierDecl == nullptr)
353         {
354             withoutPerVertexGlobalQualifierDeclarations.push_back(node);
355             continue;
356         }
357 
358         TIntermSymbol *symbol = asGlobalQualifierDecl->getSymbol();
359 
360         const int fieldIndex =
361             GetPerVertexFieldIndex(symbol->getType().getQualifier(), symbol->getName());
362         if (fieldIndex < 0)
363         {
364             withoutPerVertexGlobalQualifierDeclarations.push_back(node);
365             continue;
366         }
367 
368         if (asGlobalQualifierDecl->isInvariant())
369         {
370             invariantFlags[fieldIndex] = true;
371         }
372         else if (asGlobalQualifierDecl->isPrecise())
373         {
374             preciseFlags[fieldIndex] = true;
375         }
376     }
377 
378     // Remove the global qualifier declarations for the gl_PerVertex members.
379     root->replaceAllChildren(withoutPerVertexGlobalQualifierDeclarations);
380 
381     // If #pragma STDGL invariant(all) is specified, make all outputs invariant.
382     if (compiler->getPragma().stdgl.invariantAll)
383     {
384         std::fill(invariantFlags.begin(), invariantFlags.end(), true);
385     }
386 
387     // Then declare the in and out gl_PerVertex I/O blocks.
388     DeclarePerVertexBlocksTraverser traverser(compiler, symbolTable, invariantFlags, preciseFlags);
389     root->traverse(&traverser);
390     if (!traverser.updateTree(compiler, root))
391     {
392         return false;
393     }
394 
395     AddPerVertexDecl(root, traverser.getRedeclaredPerVertexOutVar());
396     AddPerVertexDecl(root, traverser.getRedeclaredPerVertexInVar());
397 
398     return compiler->validateAST(root);
399 }
400 }  // namespace sh
401