1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 // Analysis of the AST needed for HLSL generation
8 
9 #include "compiler/translator/ASTMetadataHLSL.h"
10 
11 #include "compiler/translator/CallDAG.h"
12 #include "compiler/translator/SymbolTable.h"
13 #include "compiler/translator/tree_util/IntermTraverse.h"
14 
15 namespace sh
16 {
17 
18 namespace
19 {
20 
21 // Class used to traverse the AST of a function definition, checking if the
22 // function uses a gradient, and writing the set of control flow using gradients.
23 // It assumes that the analysis has already been made for the function's
24 // callees.
25 class PullGradient : public TIntermTraverser
26 {
27   public:
PullGradient(MetadataList * metadataList,size_t index,const CallDAG & dag)28     PullGradient(MetadataList *metadataList, size_t index, const CallDAG &dag)
29         : TIntermTraverser(true, false, true),
30           mMetadataList(metadataList),
31           mMetadata(&(*metadataList)[index]),
32           mIndex(index),
33           mDag(dag)
34     {
35         ASSERT(index < metadataList->size());
36 
37         // ESSL 100 builtin gradient functions
38         mGradientBuiltinFunctions.insert(ImmutableString("texture2D"));
39         mGradientBuiltinFunctions.insert(ImmutableString("texture2DProj"));
40         mGradientBuiltinFunctions.insert(ImmutableString("textureCube"));
41 
42         // ESSL 300 builtin gradient functions
43         mGradientBuiltinFunctions.insert(ImmutableString("dFdx"));
44         mGradientBuiltinFunctions.insert(ImmutableString("dFdy"));
45         mGradientBuiltinFunctions.insert(ImmutableString("fwidth"));
46         mGradientBuiltinFunctions.insert(ImmutableString("texture"));
47         mGradientBuiltinFunctions.insert(ImmutableString("textureProj"));
48         mGradientBuiltinFunctions.insert(ImmutableString("textureOffset"));
49         mGradientBuiltinFunctions.insert(ImmutableString("textureProjOffset"));
50 
51         // ESSL 310 doesn't add builtin gradient functions
52     }
53 
traverse(TIntermFunctionDefinition * node)54     void traverse(TIntermFunctionDefinition *node)
55     {
56         node->traverse(this);
57         ASSERT(mParents.empty());
58     }
59 
60     // Called when a gradient operation or a call to a function using a gradient is found.
onGradient()61     void onGradient()
62     {
63         mMetadata->mUsesGradient = true;
64         // Mark the latest control flow as using a gradient.
65         if (!mParents.empty())
66         {
67             mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
68         }
69     }
70 
visitControlFlow(Visit visit,TIntermNode * node)71     void visitControlFlow(Visit visit, TIntermNode *node)
72     {
73         if (visit == PreVisit)
74         {
75             mParents.push_back(node);
76         }
77         else if (visit == PostVisit)
78         {
79             ASSERT(mParents.back() == node);
80             mParents.pop_back();
81             // A control flow's using a gradient means its parents are too.
82             if (mMetadata->mControlFlowsContainingGradient.count(node) > 0 && !mParents.empty())
83             {
84                 mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
85             }
86         }
87     }
88 
visitLoop(Visit visit,TIntermLoop * loop)89     bool visitLoop(Visit visit, TIntermLoop *loop) override
90     {
91         visitControlFlow(visit, loop);
92         return true;
93     }
94 
visitIfElse(Visit visit,TIntermIfElse * ifElse)95     bool visitIfElse(Visit visit, TIntermIfElse *ifElse) override
96     {
97         visitControlFlow(visit, ifElse);
98         return true;
99     }
100 
visitAggregate(Visit visit,TIntermAggregate * node)101     bool visitAggregate(Visit visit, TIntermAggregate *node) override
102     {
103         if (visit == PreVisit)
104         {
105             if (node->getOp() == EOpCallFunctionInAST)
106             {
107                 size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
108                 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
109 
110                 if ((*mMetadataList)[calleeIndex].mUsesGradient)
111                 {
112                     onGradient();
113                 }
114             }
115             else if (BuiltInGroup::IsBuiltIn(node->getOp()) && !BuiltInGroup::IsMath(node->getOp()))
116             {
117                 if (mGradientBuiltinFunctions.find(node->getFunction()->name()) !=
118                     mGradientBuiltinFunctions.end())
119                 {
120                     onGradient();
121                 }
122             }
123         }
124 
125         return true;
126     }
127 
128   private:
129     MetadataList *mMetadataList;
130     ASTMetadataHLSL *mMetadata;
131     size_t mIndex;
132     const CallDAG &mDag;
133 
134     // Contains a stack of the control flow nodes that are parents of the node being
135     // currently visited. It is used to mark control flows using a gradient.
136     std::vector<TIntermNode *> mParents;
137 
138     // A list of builtin functions that use gradients
139     std::set<ImmutableString> mGradientBuiltinFunctions;
140 };
141 
142 // Traverses the AST of a function definition to compute the the discontinuous loops
143 // and the if statements containing gradient loops. It assumes that the gradient loops
144 // (loops that contain a gradient) have already been computed and that it has already
145 // traversed the current function's callees.
146 class PullComputeDiscontinuousAndGradientLoops : public TIntermTraverser
147 {
148   public:
PullComputeDiscontinuousAndGradientLoops(MetadataList * metadataList,size_t index,const CallDAG & dag)149     PullComputeDiscontinuousAndGradientLoops(MetadataList *metadataList,
150                                              size_t index,
151                                              const CallDAG &dag)
152         : TIntermTraverser(true, false, true),
153           mMetadataList(metadataList),
154           mMetadata(&(*metadataList)[index]),
155           mIndex(index),
156           mDag(dag)
157     {}
158 
traverse(TIntermFunctionDefinition * node)159     void traverse(TIntermFunctionDefinition *node)
160     {
161         node->traverse(this);
162         ASSERT(mLoopsAndSwitches.empty());
163         ASSERT(mIfs.empty());
164     }
165 
166     // Called when traversing a gradient loop or a call to a function with a
167     // gradient loop in its call graph.
onGradientLoop()168     void onGradientLoop()
169     {
170         mMetadata->mHasGradientLoopInCallGraph = true;
171         // Mark the latest if as using a discontinuous loop.
172         if (!mIfs.empty())
173         {
174             mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
175         }
176     }
177 
visitLoop(Visit visit,TIntermLoop * loop)178     bool visitLoop(Visit visit, TIntermLoop *loop) override
179     {
180         if (visit == PreVisit)
181         {
182             mLoopsAndSwitches.push_back(loop);
183 
184             if (mMetadata->hasGradientInCallGraph(loop))
185             {
186                 onGradientLoop();
187             }
188         }
189         else if (visit == PostVisit)
190         {
191             ASSERT(mLoopsAndSwitches.back() == loop);
192             mLoopsAndSwitches.pop_back();
193         }
194 
195         return true;
196     }
197 
visitIfElse(Visit visit,TIntermIfElse * node)198     bool visitIfElse(Visit visit, TIntermIfElse *node) override
199     {
200         if (visit == PreVisit)
201         {
202             mIfs.push_back(node);
203         }
204         else if (visit == PostVisit)
205         {
206             ASSERT(mIfs.back() == node);
207             mIfs.pop_back();
208             // An if using a discontinuous loop means its parents ifs are also discontinuous.
209             if (mMetadata->mIfsContainingGradientLoop.count(node) > 0 && !mIfs.empty())
210             {
211                 mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
212             }
213         }
214 
215         return true;
216     }
217 
visitBranch(Visit visit,TIntermBranch * node)218     bool visitBranch(Visit visit, TIntermBranch *node) override
219     {
220         if (visit == PreVisit)
221         {
222             switch (node->getFlowOp())
223             {
224                 case EOpBreak:
225                 {
226                     ASSERT(!mLoopsAndSwitches.empty());
227                     TIntermLoop *loop = mLoopsAndSwitches.back()->getAsLoopNode();
228                     if (loop != nullptr)
229                     {
230                         mMetadata->mDiscontinuousLoops.insert(loop);
231                     }
232                 }
233                 break;
234                 case EOpContinue:
235                 {
236                     ASSERT(!mLoopsAndSwitches.empty());
237                     TIntermLoop *loop = nullptr;
238                     size_t i          = mLoopsAndSwitches.size();
239                     while (loop == nullptr && i > 0)
240                     {
241                         --i;
242                         loop = mLoopsAndSwitches.at(i)->getAsLoopNode();
243                     }
244                     ASSERT(loop != nullptr);
245                     mMetadata->mDiscontinuousLoops.insert(loop);
246                 }
247                 break;
248                 case EOpKill:
249                 case EOpReturn:
250                     // A return or discard jumps out of all the enclosing loops
251                     if (!mLoopsAndSwitches.empty())
252                     {
253                         for (TIntermNode *intermNode : mLoopsAndSwitches)
254                         {
255                             TIntermLoop *loop = intermNode->getAsLoopNode();
256                             if (loop)
257                             {
258                                 mMetadata->mDiscontinuousLoops.insert(loop);
259                             }
260                         }
261                     }
262                     break;
263                 default:
264                     UNREACHABLE();
265             }
266         }
267 
268         return true;
269     }
270 
visitAggregate(Visit visit,TIntermAggregate * node)271     bool visitAggregate(Visit visit, TIntermAggregate *node) override
272     {
273         if (visit == PreVisit && node->getOp() == EOpCallFunctionInAST)
274         {
275             size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
276             ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
277 
278             if ((*mMetadataList)[calleeIndex].mHasGradientLoopInCallGraph)
279             {
280                 onGradientLoop();
281             }
282         }
283 
284         return true;
285     }
286 
visitSwitch(Visit visit,TIntermSwitch * node)287     bool visitSwitch(Visit visit, TIntermSwitch *node) override
288     {
289         if (visit == PreVisit)
290         {
291             mLoopsAndSwitches.push_back(node);
292         }
293         else if (visit == PostVisit)
294         {
295             ASSERT(mLoopsAndSwitches.back() == node);
296             mLoopsAndSwitches.pop_back();
297         }
298         return true;
299     }
300 
301   private:
302     MetadataList *mMetadataList;
303     ASTMetadataHLSL *mMetadata;
304     size_t mIndex;
305     const CallDAG &mDag;
306 
307     std::vector<TIntermNode *> mLoopsAndSwitches;
308     std::vector<TIntermIfElse *> mIfs;
309 };
310 
311 // Tags all the functions called in a discontinuous loop
312 class PushDiscontinuousLoops : public TIntermTraverser
313 {
314   public:
PushDiscontinuousLoops(MetadataList * metadataList,size_t index,const CallDAG & dag)315     PushDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
316         : TIntermTraverser(true, true, true),
317           mMetadataList(metadataList),
318           mMetadata(&(*metadataList)[index]),
319           mIndex(index),
320           mDag(dag),
321           mNestedDiscont(mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)
322     {}
323 
traverse(TIntermFunctionDefinition * node)324     void traverse(TIntermFunctionDefinition *node)
325     {
326         node->traverse(this);
327         ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0));
328     }
329 
visitLoop(Visit visit,TIntermLoop * loop)330     bool visitLoop(Visit visit, TIntermLoop *loop) override
331     {
332         bool isDiscontinuous = mMetadata->mDiscontinuousLoops.count(loop) > 0;
333 
334         if (visit == PreVisit && isDiscontinuous)
335         {
336             mNestedDiscont++;
337         }
338         else if (visit == PostVisit && isDiscontinuous)
339         {
340             mNestedDiscont--;
341         }
342 
343         return true;
344     }
345 
visitAggregate(Visit visit,TIntermAggregate * node)346     bool visitAggregate(Visit visit, TIntermAggregate *node) override
347     {
348         switch (node->getOp())
349         {
350             case EOpCallFunctionInAST:
351                 if (visit == PreVisit && mNestedDiscont > 0)
352                 {
353                     size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
354                     ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
355 
356                     (*mMetadataList)[calleeIndex].mCalledInDiscontinuousLoop = true;
357                 }
358                 break;
359             default:
360                 break;
361         }
362         return true;
363     }
364 
365   private:
366     MetadataList *mMetadataList;
367     ASTMetadataHLSL *mMetadata;
368     size_t mIndex;
369     const CallDAG &mDag;
370 
371     int mNestedDiscont;
372 };
373 }  // namespace
374 
hasGradientInCallGraph(TIntermLoop * node)375 bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node)
376 {
377     return mControlFlowsContainingGradient.count(node) > 0;
378 }
379 
hasGradientLoop(TIntermIfElse * node)380 bool ASTMetadataHLSL::hasGradientLoop(TIntermIfElse *node)
381 {
382     return mIfsContainingGradientLoop.count(node) > 0;
383 }
384 
CreateASTMetadataHLSL(TIntermNode * root,const CallDAG & callDag)385 MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
386 {
387     MetadataList metadataList(callDag.size());
388 
389     // Compute all the information related to when gradient operations are used.
390     // We want to know for each function and control flow operation if they have
391     // a gradient operation in their call graph (shortened to "using a gradient"
392     // in the rest of the file).
393     //
394     // This computation is logically split in three steps:
395     //  1 - For each function compute if it uses a gradient in its body, ignoring
396     // calls to other user-defined functions.
397     //  2 - For each function determine if it uses a gradient in its call graph,
398     // using the result of step 1 and the CallDAG to know its callees.
399     //  3 - For each control flow statement of each function, check if it uses a
400     // gradient in the function's body, or if it calls a user-defined function that
401     // uses a gradient.
402     //
403     // We take advantage of the call graph being a DAG and instead compute 1, 2 and 3
404     // for leaves first, then going down the tree. This is correct because 1 doesn't
405     // depend on other functions, and 2 and 3 depend only on callees.
406     for (size_t i = 0; i < callDag.size(); i++)
407     {
408         PullGradient pull(&metadataList, i, callDag);
409         pull.traverse(callDag.getRecordFromIndex(i).node);
410     }
411 
412     // Compute which loops are discontinuous and which function are called in
413     // these loops. The same way computing gradient usage is a "pull" process,
414     // computing "bing used in a discont. loop" is a push process. However we also
415     // need to know what ifs have a discontinuous loop inside so we do the same type
416     // of callgraph analysis as for the gradient.
417 
418     // First compute which loops are discontinuous (no specific order) and pull
419     // the ifs and functions using a gradient loop.
420     for (size_t i = 0; i < callDag.size(); i++)
421     {
422         PullComputeDiscontinuousAndGradientLoops pull(&metadataList, i, callDag);
423         pull.traverse(callDag.getRecordFromIndex(i).node);
424     }
425 
426     // Then push the information to callees, either from the a local discontinuous
427     // loop or from the caller being called in a discontinuous loop already
428     for (size_t i = callDag.size(); i-- > 0;)
429     {
430         PushDiscontinuousLoops push(&metadataList, i, callDag);
431         push.traverse(callDag.getRecordFromIndex(i).node);
432     }
433 
434     // We create "Lod0" version of functions with the gradient operations replaced
435     // by non-gradient operations so that the D3D compiler is happier with discont
436     // loops.
437     for (auto &metadata : metadataList)
438     {
439         metadata.mNeedsLod0 = metadata.mCalledInDiscontinuousLoop && metadata.mUsesGradient;
440     }
441 
442     return metadataList;
443 }
444 
445 }  // namespace sh
446