1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 // Analysis of the AST needed for HLSL generation
8
9 #include "compiler/translator/ASTMetadataHLSL.h"
10
11 #include "compiler/translator/CallDAG.h"
12 #include "compiler/translator/SymbolTable.h"
13 #include "compiler/translator/tree_util/IntermTraverse.h"
14
15 namespace sh
16 {
17
18 namespace
19 {
20
21 // Class used to traverse the AST of a function definition, checking if the
22 // function uses a gradient, and writing the set of control flow using gradients.
23 // It assumes that the analysis has already been made for the function's
24 // callees.
25 class PullGradient : public TIntermTraverser
26 {
27 public:
PullGradient(MetadataList * metadataList,size_t index,const CallDAG & dag)28 PullGradient(MetadataList *metadataList, size_t index, const CallDAG &dag)
29 : TIntermTraverser(true, false, true),
30 mMetadataList(metadataList),
31 mMetadata(&(*metadataList)[index]),
32 mIndex(index),
33 mDag(dag)
34 {
35 ASSERT(index < metadataList->size());
36
37 // ESSL 100 builtin gradient functions
38 mGradientBuiltinFunctions.insert(ImmutableString("texture2D"));
39 mGradientBuiltinFunctions.insert(ImmutableString("texture2DProj"));
40 mGradientBuiltinFunctions.insert(ImmutableString("textureCube"));
41
42 // ESSL 300 builtin gradient functions
43 mGradientBuiltinFunctions.insert(ImmutableString("dFdx"));
44 mGradientBuiltinFunctions.insert(ImmutableString("dFdy"));
45 mGradientBuiltinFunctions.insert(ImmutableString("fwidth"));
46 mGradientBuiltinFunctions.insert(ImmutableString("texture"));
47 mGradientBuiltinFunctions.insert(ImmutableString("textureProj"));
48 mGradientBuiltinFunctions.insert(ImmutableString("textureOffset"));
49 mGradientBuiltinFunctions.insert(ImmutableString("textureProjOffset"));
50
51 // ESSL 310 doesn't add builtin gradient functions
52 }
53
traverse(TIntermFunctionDefinition * node)54 void traverse(TIntermFunctionDefinition *node)
55 {
56 node->traverse(this);
57 ASSERT(mParents.empty());
58 }
59
60 // Called when a gradient operation or a call to a function using a gradient is found.
onGradient()61 void onGradient()
62 {
63 mMetadata->mUsesGradient = true;
64 // Mark the latest control flow as using a gradient.
65 if (!mParents.empty())
66 {
67 mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
68 }
69 }
70
visitControlFlow(Visit visit,TIntermNode * node)71 void visitControlFlow(Visit visit, TIntermNode *node)
72 {
73 if (visit == PreVisit)
74 {
75 mParents.push_back(node);
76 }
77 else if (visit == PostVisit)
78 {
79 ASSERT(mParents.back() == node);
80 mParents.pop_back();
81 // A control flow's using a gradient means its parents are too.
82 if (mMetadata->mControlFlowsContainingGradient.count(node) > 0 && !mParents.empty())
83 {
84 mMetadata->mControlFlowsContainingGradient.insert(mParents.back());
85 }
86 }
87 }
88
visitLoop(Visit visit,TIntermLoop * loop)89 bool visitLoop(Visit visit, TIntermLoop *loop) override
90 {
91 visitControlFlow(visit, loop);
92 return true;
93 }
94
visitIfElse(Visit visit,TIntermIfElse * ifElse)95 bool visitIfElse(Visit visit, TIntermIfElse *ifElse) override
96 {
97 visitControlFlow(visit, ifElse);
98 return true;
99 }
100
visitAggregate(Visit visit,TIntermAggregate * node)101 bool visitAggregate(Visit visit, TIntermAggregate *node) override
102 {
103 if (visit == PreVisit)
104 {
105 if (node->getOp() == EOpCallFunctionInAST)
106 {
107 size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
108 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
109
110 if ((*mMetadataList)[calleeIndex].mUsesGradient)
111 {
112 onGradient();
113 }
114 }
115 else if (BuiltInGroup::IsBuiltIn(node->getOp()) && !BuiltInGroup::IsMath(node->getOp()))
116 {
117 if (mGradientBuiltinFunctions.find(node->getFunction()->name()) !=
118 mGradientBuiltinFunctions.end())
119 {
120 onGradient();
121 }
122 }
123 }
124
125 return true;
126 }
127
128 private:
129 MetadataList *mMetadataList;
130 ASTMetadataHLSL *mMetadata;
131 size_t mIndex;
132 const CallDAG &mDag;
133
134 // Contains a stack of the control flow nodes that are parents of the node being
135 // currently visited. It is used to mark control flows using a gradient.
136 std::vector<TIntermNode *> mParents;
137
138 // A list of builtin functions that use gradients
139 std::set<ImmutableString> mGradientBuiltinFunctions;
140 };
141
142 // Traverses the AST of a function definition to compute the the discontinuous loops
143 // and the if statements containing gradient loops. It assumes that the gradient loops
144 // (loops that contain a gradient) have already been computed and that it has already
145 // traversed the current function's callees.
146 class PullComputeDiscontinuousAndGradientLoops : public TIntermTraverser
147 {
148 public:
PullComputeDiscontinuousAndGradientLoops(MetadataList * metadataList,size_t index,const CallDAG & dag)149 PullComputeDiscontinuousAndGradientLoops(MetadataList *metadataList,
150 size_t index,
151 const CallDAG &dag)
152 : TIntermTraverser(true, false, true),
153 mMetadataList(metadataList),
154 mMetadata(&(*metadataList)[index]),
155 mIndex(index),
156 mDag(dag)
157 {}
158
traverse(TIntermFunctionDefinition * node)159 void traverse(TIntermFunctionDefinition *node)
160 {
161 node->traverse(this);
162 ASSERT(mLoopsAndSwitches.empty());
163 ASSERT(mIfs.empty());
164 }
165
166 // Called when traversing a gradient loop or a call to a function with a
167 // gradient loop in its call graph.
onGradientLoop()168 void onGradientLoop()
169 {
170 mMetadata->mHasGradientLoopInCallGraph = true;
171 // Mark the latest if as using a discontinuous loop.
172 if (!mIfs.empty())
173 {
174 mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
175 }
176 }
177
visitLoop(Visit visit,TIntermLoop * loop)178 bool visitLoop(Visit visit, TIntermLoop *loop) override
179 {
180 if (visit == PreVisit)
181 {
182 mLoopsAndSwitches.push_back(loop);
183
184 if (mMetadata->hasGradientInCallGraph(loop))
185 {
186 onGradientLoop();
187 }
188 }
189 else if (visit == PostVisit)
190 {
191 ASSERT(mLoopsAndSwitches.back() == loop);
192 mLoopsAndSwitches.pop_back();
193 }
194
195 return true;
196 }
197
visitIfElse(Visit visit,TIntermIfElse * node)198 bool visitIfElse(Visit visit, TIntermIfElse *node) override
199 {
200 if (visit == PreVisit)
201 {
202 mIfs.push_back(node);
203 }
204 else if (visit == PostVisit)
205 {
206 ASSERT(mIfs.back() == node);
207 mIfs.pop_back();
208 // An if using a discontinuous loop means its parents ifs are also discontinuous.
209 if (mMetadata->mIfsContainingGradientLoop.count(node) > 0 && !mIfs.empty())
210 {
211 mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
212 }
213 }
214
215 return true;
216 }
217
visitBranch(Visit visit,TIntermBranch * node)218 bool visitBranch(Visit visit, TIntermBranch *node) override
219 {
220 if (visit == PreVisit)
221 {
222 switch (node->getFlowOp())
223 {
224 case EOpBreak:
225 {
226 ASSERT(!mLoopsAndSwitches.empty());
227 TIntermLoop *loop = mLoopsAndSwitches.back()->getAsLoopNode();
228 if (loop != nullptr)
229 {
230 mMetadata->mDiscontinuousLoops.insert(loop);
231 }
232 }
233 break;
234 case EOpContinue:
235 {
236 ASSERT(!mLoopsAndSwitches.empty());
237 TIntermLoop *loop = nullptr;
238 size_t i = mLoopsAndSwitches.size();
239 while (loop == nullptr && i > 0)
240 {
241 --i;
242 loop = mLoopsAndSwitches.at(i)->getAsLoopNode();
243 }
244 ASSERT(loop != nullptr);
245 mMetadata->mDiscontinuousLoops.insert(loop);
246 }
247 break;
248 case EOpKill:
249 case EOpReturn:
250 // A return or discard jumps out of all the enclosing loops
251 if (!mLoopsAndSwitches.empty())
252 {
253 for (TIntermNode *intermNode : mLoopsAndSwitches)
254 {
255 TIntermLoop *loop = intermNode->getAsLoopNode();
256 if (loop)
257 {
258 mMetadata->mDiscontinuousLoops.insert(loop);
259 }
260 }
261 }
262 break;
263 default:
264 UNREACHABLE();
265 }
266 }
267
268 return true;
269 }
270
visitAggregate(Visit visit,TIntermAggregate * node)271 bool visitAggregate(Visit visit, TIntermAggregate *node) override
272 {
273 if (visit == PreVisit && node->getOp() == EOpCallFunctionInAST)
274 {
275 size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
276 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
277
278 if ((*mMetadataList)[calleeIndex].mHasGradientLoopInCallGraph)
279 {
280 onGradientLoop();
281 }
282 }
283
284 return true;
285 }
286
visitSwitch(Visit visit,TIntermSwitch * node)287 bool visitSwitch(Visit visit, TIntermSwitch *node) override
288 {
289 if (visit == PreVisit)
290 {
291 mLoopsAndSwitches.push_back(node);
292 }
293 else if (visit == PostVisit)
294 {
295 ASSERT(mLoopsAndSwitches.back() == node);
296 mLoopsAndSwitches.pop_back();
297 }
298 return true;
299 }
300
301 private:
302 MetadataList *mMetadataList;
303 ASTMetadataHLSL *mMetadata;
304 size_t mIndex;
305 const CallDAG &mDag;
306
307 std::vector<TIntermNode *> mLoopsAndSwitches;
308 std::vector<TIntermIfElse *> mIfs;
309 };
310
311 // Tags all the functions called in a discontinuous loop
312 class PushDiscontinuousLoops : public TIntermTraverser
313 {
314 public:
PushDiscontinuousLoops(MetadataList * metadataList,size_t index,const CallDAG & dag)315 PushDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
316 : TIntermTraverser(true, true, true),
317 mMetadataList(metadataList),
318 mMetadata(&(*metadataList)[index]),
319 mIndex(index),
320 mDag(dag),
321 mNestedDiscont(mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)
322 {}
323
traverse(TIntermFunctionDefinition * node)324 void traverse(TIntermFunctionDefinition *node)
325 {
326 node->traverse(this);
327 ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0));
328 }
329
visitLoop(Visit visit,TIntermLoop * loop)330 bool visitLoop(Visit visit, TIntermLoop *loop) override
331 {
332 bool isDiscontinuous = mMetadata->mDiscontinuousLoops.count(loop) > 0;
333
334 if (visit == PreVisit && isDiscontinuous)
335 {
336 mNestedDiscont++;
337 }
338 else if (visit == PostVisit && isDiscontinuous)
339 {
340 mNestedDiscont--;
341 }
342
343 return true;
344 }
345
visitAggregate(Visit visit,TIntermAggregate * node)346 bool visitAggregate(Visit visit, TIntermAggregate *node) override
347 {
348 switch (node->getOp())
349 {
350 case EOpCallFunctionInAST:
351 if (visit == PreVisit && mNestedDiscont > 0)
352 {
353 size_t calleeIndex = mDag.findIndex(node->getFunction()->uniqueId());
354 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
355
356 (*mMetadataList)[calleeIndex].mCalledInDiscontinuousLoop = true;
357 }
358 break;
359 default:
360 break;
361 }
362 return true;
363 }
364
365 private:
366 MetadataList *mMetadataList;
367 ASTMetadataHLSL *mMetadata;
368 size_t mIndex;
369 const CallDAG &mDag;
370
371 int mNestedDiscont;
372 };
373 } // namespace
374
hasGradientInCallGraph(TIntermLoop * node)375 bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node)
376 {
377 return mControlFlowsContainingGradient.count(node) > 0;
378 }
379
hasGradientLoop(TIntermIfElse * node)380 bool ASTMetadataHLSL::hasGradientLoop(TIntermIfElse *node)
381 {
382 return mIfsContainingGradientLoop.count(node) > 0;
383 }
384
CreateASTMetadataHLSL(TIntermNode * root,const CallDAG & callDag)385 MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
386 {
387 MetadataList metadataList(callDag.size());
388
389 // Compute all the information related to when gradient operations are used.
390 // We want to know for each function and control flow operation if they have
391 // a gradient operation in their call graph (shortened to "using a gradient"
392 // in the rest of the file).
393 //
394 // This computation is logically split in three steps:
395 // 1 - For each function compute if it uses a gradient in its body, ignoring
396 // calls to other user-defined functions.
397 // 2 - For each function determine if it uses a gradient in its call graph,
398 // using the result of step 1 and the CallDAG to know its callees.
399 // 3 - For each control flow statement of each function, check if it uses a
400 // gradient in the function's body, or if it calls a user-defined function that
401 // uses a gradient.
402 //
403 // We take advantage of the call graph being a DAG and instead compute 1, 2 and 3
404 // for leaves first, then going down the tree. This is correct because 1 doesn't
405 // depend on other functions, and 2 and 3 depend only on callees.
406 for (size_t i = 0; i < callDag.size(); i++)
407 {
408 PullGradient pull(&metadataList, i, callDag);
409 pull.traverse(callDag.getRecordFromIndex(i).node);
410 }
411
412 // Compute which loops are discontinuous and which function are called in
413 // these loops. The same way computing gradient usage is a "pull" process,
414 // computing "bing used in a discont. loop" is a push process. However we also
415 // need to know what ifs have a discontinuous loop inside so we do the same type
416 // of callgraph analysis as for the gradient.
417
418 // First compute which loops are discontinuous (no specific order) and pull
419 // the ifs and functions using a gradient loop.
420 for (size_t i = 0; i < callDag.size(); i++)
421 {
422 PullComputeDiscontinuousAndGradientLoops pull(&metadataList, i, callDag);
423 pull.traverse(callDag.getRecordFromIndex(i).node);
424 }
425
426 // Then push the information to callees, either from the a local discontinuous
427 // loop or from the caller being called in a discontinuous loop already
428 for (size_t i = callDag.size(); i-- > 0;)
429 {
430 PushDiscontinuousLoops push(&metadataList, i, callDag);
431 push.traverse(callDag.getRecordFromIndex(i).node);
432 }
433
434 // We create "Lod0" version of functions with the gradient operations replaced
435 // by non-gradient operations so that the D3D compiler is happier with discont
436 // loops.
437 for (auto &metadata : metadataList)
438 {
439 metadata.mNeedsLod0 = metadata.mCalledInDiscontinuousLoop && metadata.mUsesGradient;
440 }
441
442 return metadataList;
443 }
444
445 } // namespace sh
446