1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
#include "compiler/translator/tree_util/IntermTraverse.h"

#include <functional>

#include "compiler/translator/Compiler.h"
#include "compiler/translator/InfoSink.h"
#include "compiler/translator/SymbolTable.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
13 
14 namespace sh
15 {
16 
17 // Traverse the intermediate representation tree, and call a node type specific visit function for
18 // each node. Traversal is done recursively through the node member function traverse(). Nodes with
19 // children can have their whole subtree skipped if preVisit is turned on and the type specific
20 // function returns false.
21 template <typename T>
traverse(T * node)22 void TIntermTraverser::traverse(T *node)
23 {
24     ScopedNodeInTraversalPath addToPath(this, node);
25     if (!addToPath.isWithinDepthLimit())
26         return;
27 
28     bool visit = true;
29 
30     // Visit the node before children if pre-visiting.
31     if (preVisit)
32         visit = node->visit(PreVisit, this);
33 
34     if (visit)
35     {
36         size_t childIndex = 0;
37         size_t childCount = node->getChildCount();
38 
39         while (childIndex < childCount && visit)
40         {
41             mCurrentChildIndex = childIndex;
42             node->getChildNode(childIndex)->traverse(this);
43             mCurrentChildIndex = childIndex;
44 
45             if (inVisit && childIndex != childCount - 1)
46             {
47                 visit = node->visit(InVisit, this);
48             }
49             ++childIndex;
50         }
51 
52         if (visit && postVisit)
53             node->visit(PostVisit, this);
54     }
55 }
56 
57 // Instantiate template for RewriteAtomicFunctionExpressions, in case this gets inlined thus not
58 // exported from the TU.
59 template void TIntermTraverser::traverse(TIntermNode *);
60 
traverse(TIntermTraverser * it)61 void TIntermNode::traverse(TIntermTraverser *it)
62 {
63     it->traverse(this);
64 }
65 
traverse(TIntermTraverser * it)66 void TIntermSymbol::traverse(TIntermTraverser *it)
67 {
68     TIntermTraverser::ScopedNodeInTraversalPath addToPath(it, this);
69     it->visitSymbol(this);
70 }
71 
traverse(TIntermTraverser * it)72 void TIntermConstantUnion::traverse(TIntermTraverser *it)
73 {
74     TIntermTraverser::ScopedNodeInTraversalPath addToPath(it, this);
75     it->visitConstantUnion(this);
76 }
77 
traverse(TIntermTraverser * it)78 void TIntermFunctionPrototype::traverse(TIntermTraverser *it)
79 {
80     TIntermTraverser::ScopedNodeInTraversalPath addToPath(it, this);
81     it->visitFunctionPrototype(this);
82 }
83 
traverse(TIntermTraverser * it)84 void TIntermBinary::traverse(TIntermTraverser *it)
85 {
86     it->traverseBinary(this);
87 }
88 
traverse(TIntermTraverser * it)89 void TIntermUnary::traverse(TIntermTraverser *it)
90 {
91     it->traverseUnary(this);
92 }
93 
traverse(TIntermTraverser * it)94 void TIntermFunctionDefinition::traverse(TIntermTraverser *it)
95 {
96     it->traverseFunctionDefinition(this);
97 }
98 
traverse(TIntermTraverser * it)99 void TIntermBlock::traverse(TIntermTraverser *it)
100 {
101     it->traverseBlock(this);
102 }
103 
traverse(TIntermTraverser * it)104 void TIntermAggregate::traverse(TIntermTraverser *it)
105 {
106     it->traverseAggregate(this);
107 }
108 
traverse(TIntermTraverser * it)109 void TIntermLoop::traverse(TIntermTraverser *it)
110 {
111     it->traverseLoop(this);
112 }
113 
traverse(TIntermTraverser * it)114 void TIntermPreprocessorDirective::traverse(TIntermTraverser *it)
115 {
116     it->visitPreprocessorDirective(this);
117 }
118 
visit(Visit visit,TIntermTraverser * it)119 bool TIntermSymbol::visit(Visit visit, TIntermTraverser *it)
120 {
121     it->visitSymbol(this);
122     return false;
123 }
124 
visit(Visit visit,TIntermTraverser * it)125 bool TIntermConstantUnion::visit(Visit visit, TIntermTraverser *it)
126 {
127     it->visitConstantUnion(this);
128     return false;
129 }
130 
visit(Visit visit,TIntermTraverser * it)131 bool TIntermFunctionPrototype::visit(Visit visit, TIntermTraverser *it)
132 {
133     it->visitFunctionPrototype(this);
134     return false;
135 }
136 
visit(Visit visit,TIntermTraverser * it)137 bool TIntermFunctionDefinition::visit(Visit visit, TIntermTraverser *it)
138 {
139     return it->visitFunctionDefinition(visit, this);
140 }
141 
visit(Visit visit,TIntermTraverser * it)142 bool TIntermUnary::visit(Visit visit, TIntermTraverser *it)
143 {
144     return it->visitUnary(visit, this);
145 }
146 
visit(Visit visit,TIntermTraverser * it)147 bool TIntermSwizzle::visit(Visit visit, TIntermTraverser *it)
148 {
149     return it->visitSwizzle(visit, this);
150 }
151 
visit(Visit visit,TIntermTraverser * it)152 bool TIntermBinary::visit(Visit visit, TIntermTraverser *it)
153 {
154     return it->visitBinary(visit, this);
155 }
156 
visit(Visit visit,TIntermTraverser * it)157 bool TIntermTernary::visit(Visit visit, TIntermTraverser *it)
158 {
159     return it->visitTernary(visit, this);
160 }
161 
visit(Visit visit,TIntermTraverser * it)162 bool TIntermAggregate::visit(Visit visit, TIntermTraverser *it)
163 {
164     return it->visitAggregate(visit, this);
165 }
166 
visit(Visit visit,TIntermTraverser * it)167 bool TIntermDeclaration::visit(Visit visit, TIntermTraverser *it)
168 {
169     return it->visitDeclaration(visit, this);
170 }
171 
visit(Visit visit,TIntermTraverser * it)172 bool TIntermGlobalQualifierDeclaration::visit(Visit visit, TIntermTraverser *it)
173 {
174     return it->visitGlobalQualifierDeclaration(visit, this);
175 }
176 
visit(Visit visit,TIntermTraverser * it)177 bool TIntermBlock::visit(Visit visit, TIntermTraverser *it)
178 {
179     return it->visitBlock(visit, this);
180 }
181 
visit(Visit visit,TIntermTraverser * it)182 bool TIntermIfElse::visit(Visit visit, TIntermTraverser *it)
183 {
184     return it->visitIfElse(visit, this);
185 }
186 
visit(Visit visit,TIntermTraverser * it)187 bool TIntermLoop::visit(Visit visit, TIntermTraverser *it)
188 {
189     return it->visitLoop(visit, this);
190 }
191 
visit(Visit visit,TIntermTraverser * it)192 bool TIntermBranch::visit(Visit visit, TIntermTraverser *it)
193 {
194     return it->visitBranch(visit, this);
195 }
196 
visit(Visit visit,TIntermTraverser * it)197 bool TIntermSwitch::visit(Visit visit, TIntermTraverser *it)
198 {
199     return it->visitSwitch(visit, this);
200 }
201 
visit(Visit visit,TIntermTraverser * it)202 bool TIntermCase::visit(Visit visit, TIntermTraverser *it)
203 {
204     return it->visitCase(visit, this);
205 }
206 
visit(Visit visit,TIntermTraverser * it)207 bool TIntermPreprocessorDirective::visit(Visit visit, TIntermTraverser *it)
208 {
209     it->visitPreprocessorDirective(this);
210     return false;
211 }
212 
TIntermTraverser(bool preVisit,bool inVisit,bool postVisit,TSymbolTable * symbolTable)213 TIntermTraverser::TIntermTraverser(bool preVisit,
214                                    bool inVisit,
215                                    bool postVisit,
216                                    TSymbolTable *symbolTable)
217     : preVisit(preVisit),
218       inVisit(inVisit),
219       postVisit(postVisit),
220       mMaxDepth(0),
221       mMaxAllowedDepth(std::numeric_limits<int>::max()),
222       mInGlobalScope(true),
223       mSymbolTable(symbolTable),
224       mCurrentChildIndex(0)
225 {
226     // Only enabling inVisit is not supported.
227     ASSERT(!(inVisit && !preVisit && !postVisit));
228 }
229 
~TIntermTraverser()230 TIntermTraverser::~TIntermTraverser() {}
231 
setMaxAllowedDepth(int depth)232 void TIntermTraverser::setMaxAllowedDepth(int depth)
233 {
234     mMaxAllowedDepth = depth;
235 }
236 
getParentBlock() const237 const TIntermBlock *TIntermTraverser::getParentBlock() const
238 {
239     if (!mParentBlockStack.empty())
240     {
241         return mParentBlockStack.back().node;
242     }
243     return nullptr;
244 }
245 
pushParentBlock(TIntermBlock * node)246 void TIntermTraverser::pushParentBlock(TIntermBlock *node)
247 {
248     mParentBlockStack.push_back(ParentBlock(node, 0));
249 }
250 
incrementParentBlockPos()251 void TIntermTraverser::incrementParentBlockPos()
252 {
253     ++mParentBlockStack.back().pos;
254 }
255 
popParentBlock()256 void TIntermTraverser::popParentBlock()
257 {
258     ASSERT(!mParentBlockStack.empty());
259     mParentBlockStack.pop_back();
260 }
261 
insertStatementsInParentBlock(const TIntermSequence & insertions)262 void TIntermTraverser::insertStatementsInParentBlock(const TIntermSequence &insertions)
263 {
264     TIntermSequence emptyInsertionsAfter;
265     insertStatementsInParentBlock(insertions, emptyInsertionsAfter);
266 }
267 
insertStatementsInParentBlock(const TIntermSequence & insertionsBefore,const TIntermSequence & insertionsAfter)268 void TIntermTraverser::insertStatementsInParentBlock(const TIntermSequence &insertionsBefore,
269                                                      const TIntermSequence &insertionsAfter)
270 {
271     ASSERT(!mParentBlockStack.empty());
272     ParentBlock &parentBlock = mParentBlockStack.back();
273     if (mPath.back() == parentBlock.node)
274     {
275         ASSERT(mParentBlockStack.size() >= 2u);
276         // The current node is a block node, so the parent block is not the topmost one in the block
277         // stack, but the one below that.
278         parentBlock = mParentBlockStack.at(mParentBlockStack.size() - 2u);
279     }
280     NodeInsertMultipleEntry insert(parentBlock.node, parentBlock.pos, insertionsBefore,
281                                    insertionsAfter);
282     mInsertions.push_back(insert);
283 }
284 
insertStatementInParentBlock(TIntermNode * statement)285 void TIntermTraverser::insertStatementInParentBlock(TIntermNode *statement)
286 {
287     TIntermSequence insertions;
288     insertions.push_back(statement);
289     insertStatementsInParentBlock(insertions);
290 }
291 
insertStatementsInBlockAtPosition(TIntermBlock * parent,size_t position,const TIntermSequence & insertionsBefore,const TIntermSequence & insertionsAfter)292 void TIntermTraverser::insertStatementsInBlockAtPosition(TIntermBlock *parent,
293                                                          size_t position,
294                                                          const TIntermSequence &insertionsBefore,
295                                                          const TIntermSequence &insertionsAfter)
296 {
297     ASSERT(parent);
298     ASSERT(position >= 0);
299     ASSERT(position < parent->getChildCount());
300 
301     mInsertions.emplace_back(parent, position, insertionsBefore, insertionsAfter);
302 }
303 
setInFunctionCallOutParameter(bool inOutParameter)304 void TLValueTrackingTraverser::setInFunctionCallOutParameter(bool inOutParameter)
305 {
306     mInFunctionCallOutParameter = inOutParameter;
307 }
308 
isInFunctionCallOutParameter() const309 bool TLValueTrackingTraverser::isInFunctionCallOutParameter() const
310 {
311     return mInFunctionCallOutParameter;
312 }
313 
traverseBinary(TIntermBinary * node)314 void TIntermTraverser::traverseBinary(TIntermBinary *node)
315 {
316     traverse(node);
317 }
318 
traverseBinary(TIntermBinary * node)319 void TLValueTrackingTraverser::traverseBinary(TIntermBinary *node)
320 {
321     ScopedNodeInTraversalPath addToPath(this, node);
322     if (!addToPath.isWithinDepthLimit())
323         return;
324 
325     bool visit = true;
326 
327     // visit the node before children if pre-visiting.
328     if (preVisit)
329         visit = node->visit(PreVisit, this);
330 
331     // Visit the children, in the right order.
332     if (visit)
333     {
334         if (node->isAssignment())
335         {
336             ASSERT(!isLValueRequiredHere());
337             setOperatorRequiresLValue(true);
338         }
339 
340         node->getLeft()->traverse(this);
341 
342         if (node->isAssignment())
343             setOperatorRequiresLValue(false);
344 
345         if (inVisit)
346             visit = node->visit(InVisit, this);
347 
348         if (visit)
349         {
350             // Some binary operations like indexing can be inside an expression which must be an
351             // l-value.
352             bool parentOperatorRequiresLValue     = operatorRequiresLValue();
353             bool parentInFunctionCallOutParameter = isInFunctionCallOutParameter();
354 
355             // Index is not required to be an l-value even when the surrounding expression is
356             // required to be an l-value.
357             TOperator op = node->getOp();
358             if (op == EOpIndexDirect || op == EOpIndexDirectInterfaceBlock ||
359                 op == EOpIndexDirectStruct || op == EOpIndexIndirect)
360             {
361                 setOperatorRequiresLValue(false);
362                 setInFunctionCallOutParameter(false);
363             }
364 
365             node->getRight()->traverse(this);
366 
367             setOperatorRequiresLValue(parentOperatorRequiresLValue);
368             setInFunctionCallOutParameter(parentInFunctionCallOutParameter);
369 
370             // Visit the node after the children, if requested and the traversal
371             // hasn't been cancelled yet.
372             if (postVisit)
373                 visit = node->visit(PostVisit, this);
374         }
375     }
376 }
377 
traverseUnary(TIntermUnary * node)378 void TIntermTraverser::traverseUnary(TIntermUnary *node)
379 {
380     traverse(node);
381 }
382 
traverseUnary(TIntermUnary * node)383 void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node)
384 {
385     ScopedNodeInTraversalPath addToPath(this, node);
386     if (!addToPath.isWithinDepthLimit())
387         return;
388 
389     bool visit = true;
390 
391     if (preVisit)
392         visit = node->visit(PreVisit, this);
393 
394     if (visit)
395     {
396         ASSERT(!operatorRequiresLValue());
397         switch (node->getOp())
398         {
399             case EOpPostIncrement:
400             case EOpPostDecrement:
401             case EOpPreIncrement:
402             case EOpPreDecrement:
403                 setOperatorRequiresLValue(true);
404                 break;
405             default:
406                 break;
407         }
408 
409         node->getOperand()->traverse(this);
410 
411         setOperatorRequiresLValue(false);
412 
413         if (postVisit)
414             visit = node->visit(PostVisit, this);
415     }
416 }
417 
418 // Traverse a function definition node. This keeps track of global scope.
traverseFunctionDefinition(TIntermFunctionDefinition * node)419 void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *node)
420 {
421     ScopedNodeInTraversalPath addToPath(this, node);
422     if (!addToPath.isWithinDepthLimit())
423         return;
424 
425     bool visit = true;
426 
427     if (preVisit)
428         visit = node->visit(PreVisit, this);
429 
430     if (visit)
431     {
432         mCurrentChildIndex = 0;
433         node->getFunctionPrototype()->traverse(this);
434         mCurrentChildIndex = 0;
435 
436         if (inVisit)
437             visit = node->visit(InVisit, this);
438         if (visit)
439         {
440             mInGlobalScope     = false;
441             mCurrentChildIndex = 1;
442             node->getBody()->traverse(this);
443             mCurrentChildIndex = 1;
444             mInGlobalScope     = true;
445             if (postVisit)
446                 visit = node->visit(PostVisit, this);
447         }
448     }
449 }
450 
451 // Traverse a block node. This keeps track of the position of traversed child nodes within the block
452 // so that nodes may be inserted before or after them.
traverseBlock(TIntermBlock * node)453 void TIntermTraverser::traverseBlock(TIntermBlock *node)
454 {
455     ScopedNodeInTraversalPath addToPath(this, node);
456     if (!addToPath.isWithinDepthLimit())
457         return;
458 
459     pushParentBlock(node);
460 
461     bool visit = true;
462 
463     TIntermSequence *sequence = node->getSequence();
464 
465     if (preVisit)
466         visit = node->visit(PreVisit, this);
467 
468     if (visit)
469     {
470         for (size_t childIndex = 0; childIndex < sequence->size(); ++childIndex)
471         {
472             TIntermNode *child = (*sequence)[childIndex];
473             if (visit)
474             {
475                 mCurrentChildIndex = childIndex;
476                 child->traverse(this);
477                 mCurrentChildIndex = childIndex;
478 
479                 if (inVisit)
480                 {
481                     if (child != sequence->back())
482                         visit = node->visit(InVisit, this);
483                 }
484 
485                 incrementParentBlockPos();
486             }
487         }
488 
489         if (visit && postVisit)
490             visit = node->visit(PostVisit, this);
491     }
492 
493     popParentBlock();
494 }
495 
traverseAggregate(TIntermAggregate * node)496 void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
497 {
498     traverse(node);
499 }
500 
CompareInsertion(const NodeInsertMultipleEntry & a,const NodeInsertMultipleEntry & b)501 bool TIntermTraverser::CompareInsertion(const NodeInsertMultipleEntry &a,
502                                         const NodeInsertMultipleEntry &b)
503 {
504     if (a.parent != b.parent)
505     {
506         return a.parent < b.parent;
507     }
508     return a.position < b.position;
509 }
510 
updateTree(TCompiler * compiler,TIntermNode * node)511 bool TIntermTraverser::updateTree(TCompiler *compiler, TIntermNode *node)
512 {
513     // Sort the insertions so that insertion position is increasing and same position insertions are
514     // not reordered. The insertions are processed in reverse order so that multiple insertions to
515     // the same parent node are handled correctly.
516     std::stable_sort(mInsertions.begin(), mInsertions.end(), CompareInsertion);
517     for (size_t ii = 0; ii < mInsertions.size(); ++ii)
518     {
519         // If two insertions are to the same position, insert them in the order they were specified.
520         // The std::stable_sort call above will automatically guarantee this.
521         const NodeInsertMultipleEntry &insertion = mInsertions[mInsertions.size() - ii - 1];
522         ASSERT(insertion.parent);
523         if (!insertion.insertionsAfter.empty())
524         {
525             bool inserted = insertion.parent->insertChildNodes(insertion.position + 1,
526                                                                insertion.insertionsAfter);
527             ASSERT(inserted);
528         }
529         if (!insertion.insertionsBefore.empty())
530         {
531             bool inserted =
532                 insertion.parent->insertChildNodes(insertion.position, insertion.insertionsBefore);
533             ASSERT(inserted);
534         }
535     }
536     for (size_t ii = 0; ii < mReplacements.size(); ++ii)
537     {
538         const NodeUpdateEntry &replacement = mReplacements[ii];
539         ASSERT(replacement.parent);
540         bool replaced =
541             replacement.parent->replaceChildNode(replacement.original, replacement.replacement);
542         ASSERT(replaced);
543 
544         if (!replacement.originalBecomesChildOfReplacement)
545         {
546             // In AST traversing, a parent is visited before its children.
547             // After we replace a node, if its immediate child is to
548             // be replaced, we need to make sure we don't update the replaced
549             // node; instead, we update the replacement node.
550             for (size_t jj = ii + 1; jj < mReplacements.size(); ++jj)
551             {
552                 NodeUpdateEntry &replacement2 = mReplacements[jj];
553                 if (replacement2.parent == replacement.original)
554                     replacement2.parent = replacement.replacement;
555             }
556         }
557     }
558     for (size_t ii = 0; ii < mMultiReplacements.size(); ++ii)
559     {
560         const NodeReplaceWithMultipleEntry &replacement = mMultiReplacements[ii];
561         ASSERT(replacement.parent);
562         bool replaced = replacement.parent->replaceChildNodeWithMultiple(replacement.original,
563                                                                          replacement.replacements);
564         ASSERT(replaced);
565     }
566 
567     clearReplacementQueue();
568 
569     return compiler->validateAST(node);
570 }
571 
clearReplacementQueue()572 void TIntermTraverser::clearReplacementQueue()
573 {
574     mReplacements.clear();
575     mMultiReplacements.clear();
576     mInsertions.clear();
577 }
578 
queueReplacement(TIntermNode * replacement,OriginalNode originalStatus)579 void TIntermTraverser::queueReplacement(TIntermNode *replacement, OriginalNode originalStatus)
580 {
581     queueReplacementWithParent(getParentNode(), mPath.back(), replacement, originalStatus);
582 }
583 
queueReplacementWithParent(TIntermNode * parent,TIntermNode * original,TIntermNode * replacement,OriginalNode originalStatus)584 void TIntermTraverser::queueReplacementWithParent(TIntermNode *parent,
585                                                   TIntermNode *original,
586                                                   TIntermNode *replacement,
587                                                   OriginalNode originalStatus)
588 {
589     bool originalBecomesChild = (originalStatus == OriginalNode::BECOMES_CHILD);
590     mReplacements.push_back(NodeUpdateEntry(parent, original, replacement, originalBecomesChild));
591 }
592 
TLValueTrackingTraverser(bool preVisitIn,bool inVisitIn,bool postVisitIn,TSymbolTable * symbolTable)593 TLValueTrackingTraverser::TLValueTrackingTraverser(bool preVisitIn,
594                                                    bool inVisitIn,
595                                                    bool postVisitIn,
596                                                    TSymbolTable *symbolTable)
597     : TIntermTraverser(preVisitIn, inVisitIn, postVisitIn, symbolTable),
598       mOperatorRequiresLValue(false),
599       mInFunctionCallOutParameter(false)
600 {
601     ASSERT(symbolTable);
602 }
603 
traverseAggregate(TIntermAggregate * node)604 void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
605 {
606     ScopedNodeInTraversalPath addToPath(this, node);
607     if (!addToPath.isWithinDepthLimit())
608         return;
609 
610     bool visit = true;
611 
612     TIntermSequence *sequence = node->getSequence();
613 
614     if (preVisit)
615         visit = node->visit(PreVisit, this);
616 
617     if (visit)
618     {
619         size_t paramIndex = 0u;
620         for (auto *child : *sequence)
621         {
622             if (visit)
623             {
624                 if (node->getFunction())
625                 {
626                     // Both built-ins and user defined functions should have the function symbol
627                     // set.
628                     ASSERT(paramIndex < node->getFunction()->getParamCount());
629                     TQualifier qualifier =
630                         node->getFunction()->getParam(paramIndex)->getType().getQualifier();
631                     setInFunctionCallOutParameter(qualifier == EvqOut || qualifier == EvqInOut);
632                     ++paramIndex;
633                 }
634                 else
635                 {
636                     ASSERT(node->isConstructor());
637                 }
638                 child->traverse(this);
639                 if (inVisit)
640                 {
641                     if (child != sequence->back())
642                         visit = node->visit(InVisit, this);
643                 }
644             }
645         }
646         setInFunctionCallOutParameter(false);
647 
648         if (visit && postVisit)
649             visit = node->visit(PostVisit, this);
650     }
651 }
652 
traverseLoop(TIntermLoop * node)653 void TIntermTraverser::traverseLoop(TIntermLoop *node)
654 {
655     traverse(node);
656 }
657 }  // namespace sh
658