1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
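// RewriteSampleMaskVariable.cpp: Find any references to gl_SampleMask and gl_SampleMaskIn, and
// rewrite every non-constant index into them as the constant index 0. For gl_SampleMask, also
// append code at the end of the shader that sets all mask bits when the numSamples uniform is 1.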
8 //
9 
10 #include "compiler/translator/tree_util/RewriteSampleMaskVariable.h"
11 
12 #include "common/bitset_utils.h"
13 #include "common/debug.h"
14 #include "common/utilities.h"
15 #include "compiler/translator/SymbolTable.h"
16 #include "compiler/translator/tree_util/BuiltIn.h"
17 #include "compiler/translator/tree_util/IntermNode_util.h"
18 #include "compiler/translator/tree_util/IntermTraverse.h"
19 #include "compiler/translator/tree_util/RunAtTheBeginningOfShader.h"
20 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
21 
22 namespace sh
23 {
24 namespace
25 {
// The only valid index into gl_SampleMask / gl_SampleMaskIn: with at most 32 samples supported,
// those arrays hold a single element, so every rewritten access uses index 0.
constexpr int kMaxIndexForSampleMaskVar = 0;
// Mask with all 32 sample bits set, expressed as the int value int(0xFFFFFFFF) because it is
// written into an int-typed gl_SampleMask element.
constexpr int kFullSampleMask = static_cast<int>(0xFFFFFFFFu);
28 
29 // Traverse the tree and collect the redeclaration and replace all non constant index references of
30 // gl_SampleMask or gl_SampleMaskIn with constant index references
31 class GLSampleMaskRelatedReferenceTraverser : public TIntermTraverser
32 {
33   public:
GLSampleMaskRelatedReferenceTraverser(const TIntermSymbol ** redeclaredSymOut,const ImmutableString & targetStr)34     GLSampleMaskRelatedReferenceTraverser(const TIntermSymbol **redeclaredSymOut,
35                                           const ImmutableString &targetStr)
36         : TIntermTraverser(true, false, false),
37           mRedeclaredSym(redeclaredSymOut),
38           mTargetStr(targetStr)
39     {
40         *mRedeclaredSym = nullptr;
41     }
42 
visitDeclaration(Visit visit,TIntermDeclaration * node)43     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
44     {
45         // If gl_SampleMask is redeclared, we need to collect its information
46         const TIntermSequence &sequence = *(node->getSequence());
47 
48         if (sequence.size() != 1)
49         {
50             return true;
51         }
52 
53         TIntermTyped *variable = sequence.front()->getAsTyped();
54         TIntermSymbol *symbol  = variable->getAsSymbolNode();
55         if (symbol == nullptr || symbol->getName() != mTargetStr)
56         {
57             return true;
58         }
59 
60         *mRedeclaredSym = symbol;
61 
62         return true;
63     }
64 
visitBinary(Visit visit,TIntermBinary * node)65     bool visitBinary(Visit visit, TIntermBinary *node) override
66     {
67         TOperator op = node->getOp();
68         if (op != EOpIndexDirect && op != EOpIndexIndirect)
69         {
70             return true;
71         }
72         TIntermSymbol *left = node->getLeft()->getAsSymbolNode();
73         if (!left)
74         {
75             return true;
76         }
77         if (left->getName() != mTargetStr)
78         {
79             return true;
80         }
81         const TConstantUnion *constIdx = node->getRight()->getConstantValue();
82         if (!constIdx)
83         {
84             if (node->getRight()->hasSideEffects())
85             {
86                 insertStatementInParentBlock(node->getRight());
87             }
88 
89             queueReplacementWithParent(node, node->getRight(),
90                                        CreateIndexNode(kMaxIndexForSampleMaskVar),
91                                        OriginalNode::IS_DROPPED);
92         }
93 
94         return true;
95     }
96 
97   private:
98     const TIntermSymbol **mRedeclaredSym;
99     const ImmutableString mTargetStr;
100 };
101 
102 }  // anonymous namespace
103 
RewriteSampleMask(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const TIntermTyped * numSamplesUniform)104 ANGLE_NO_DISCARD bool RewriteSampleMask(TCompiler *compiler,
105                                         TIntermBlock *root,
106                                         TSymbolTable *symbolTable,
107                                         const TIntermTyped *numSamplesUniform)
108 {
109     const TIntermSymbol *redeclaredGLSampleMask = nullptr;
110     GLSampleMaskRelatedReferenceTraverser indexTraverser(&redeclaredGLSampleMask,
111                                                          ImmutableString("gl_SampleMask"));
112 
113     root->traverse(&indexTraverser);
114     if (!indexTraverser.updateTree(compiler, root))
115     {
116         return false;
117     }
118 
119     // Retrieve gl_SampleMask variable reference
120     // Search user redeclared it first
121     const TVariable *glSampleMaskVar = nullptr;
122     if (redeclaredGLSampleMask)
123     {
124         glSampleMaskVar = &redeclaredGLSampleMask->variable();
125     }
126     else
127     {
128         // User defined not found, find in built-in table
129         glSampleMaskVar = static_cast<const TVariable *>(
130             symbolTable->findBuiltIn(ImmutableString("gl_SampleMask"), 320));
131     }
132     if (!glSampleMaskVar)
133     {
134         return false;
135     }
136 
137     // Current ANGLE assumes that the maximum number of samples is less than or equal to
138     // VK_SAMPLE_COUNT_32_BIT. So, the size of gl_SampleMask array is always one.
139     const unsigned int arraySizeOfSampleMask = glSampleMaskVar->getType().getOutermostArraySize();
140     ASSERT(arraySizeOfSampleMask == 1);
141 
142     TIntermSymbol *glSampleMaskSymbol = new TIntermSymbol(glSampleMaskVar);
143 
144     // if (ANGLEUniforms.numSamples == 1)
145     // {
146     //     gl_SampleMask[0] = int(0xFFFFFFFF);
147     // }
148     TIntermConstantUnion *singleSampleCount = CreateUIntNode(1);
149     TIntermBinary *equalTo =
150         new TIntermBinary(EOpEqual, numSamplesUniform->deepCopy(), singleSampleCount);
151 
152     TIntermBlock *trueBlock = new TIntermBlock();
153 
154     TIntermBinary *sampleMaskVar = new TIntermBinary(EOpIndexDirect, glSampleMaskSymbol->deepCopy(),
155                                                      CreateIndexNode(kMaxIndexForSampleMaskVar));
156     TIntermConstantUnion *fullSampleMask = CreateIndexNode(kFullSampleMask);
157     TIntermBinary *assignment = new TIntermBinary(EOpAssign, sampleMaskVar, fullSampleMask);
158 
159     trueBlock->appendStatement(assignment);
160 
161     TIntermIfElse *multiSampleOrNot = new TIntermIfElse(equalTo, trueBlock, nullptr);
162 
163     return RunAtTheEndOfShader(compiler, root, multiSampleOrNot, symbolTable);
164 }
165 
RewriteSampleMaskIn(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable)166 ANGLE_NO_DISCARD bool RewriteSampleMaskIn(TCompiler *compiler,
167                                           TIntermBlock *root,
168                                           TSymbolTable *symbolTable)
169 {
170     const TIntermSymbol *redeclaredGLSampleMaskIn = nullptr;
171     GLSampleMaskRelatedReferenceTraverser indexTraverser(&redeclaredGLSampleMaskIn,
172                                                          ImmutableString("gl_SampleMaskIn"));
173 
174     root->traverse(&indexTraverser);
175     if (!indexTraverser.updateTree(compiler, root))
176     {
177         return false;
178     }
179 
180     // Retrieve gl_SampleMaskIn variable reference
181     const TVariable *glSampleMaskInVar = nullptr;
182     glSampleMaskInVar                  = static_cast<const TVariable *>(
183         symbolTable->findBuiltIn(ImmutableString("gl_SampleMaskIn"), 320));
184     if (!glSampleMaskInVar)
185     {
186         return false;
187     }
188 
189     // Current ANGLE assumes that the maximum number of samples is less than or equal to
190     // VK_SAMPLE_COUNT_32_BIT. So, the size of gl_SampleMask array is always one.
191     const unsigned int arraySizeOfSampleMaskIn =
192         glSampleMaskInVar->getType().getOutermostArraySize();
193     ASSERT(arraySizeOfSampleMaskIn == 1);
194 
195     return true;
196 }
197 
198 }  // namespace sh
199