1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <algorithm>
8 
9 #include "compiler/translator/Compiler.h"
10 #include "compiler/translator/SymbolTable.h"
11 #include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
12 #include "compiler/translator/tree_util/AsNode.h"
13 
// GUARD2(cond, failVal): if `cond` evaluates to false, return `failVal` from the enclosing
// function; otherwise fall through. The do/while(false) wrapper makes the expansion a
// single statement (safe inside an unbraced if/else), and the `break` only leaves that
// wrapper loop, never an enclosing loop or switch of the caller.
#define GUARD2(cond, failVal) \
    do                        \
    {                         \
        if (cond)             \
        {                     \
            break;            \
        }                     \
        return failVal;       \
    } while (false)

// GUARD(cond): GUARD2 specialised for functions that report failure with a null pointer.
#define GUARD(cond) GUARD2(cond, nullptr)
24 
25 namespace sh
26 {
27 
28 template <typename T, typename U>
AllBits(T haystack,U needle)29 ANGLE_INLINE bool AllBits(T haystack, U needle)
30 {
31     return (haystack & needle) == needle;
32 }
33 
34 template <typename T, typename U>
AnyBits(T haystack,U needle)35 ANGLE_INLINE bool AnyBits(T haystack, U needle)
36 {
37     return (haystack & needle) != 0;
38 }
39 
40 ////////////////////////////////////////////////////////////////////////////////
41 
BaseResult(BaseResult & other)42 TIntermRebuild::BaseResult::BaseResult(BaseResult &other)
43     : mAction(other.mAction),
44       mVisit(other.mVisit),
45       mSingle(other.mSingle),
46       mMulti(std::move(other.mMulti))
47 {}
48 
BaseResult(TIntermNode & node,VisitBits visit)49 TIntermRebuild::BaseResult::BaseResult(TIntermNode &node, VisitBits visit)
50     : mAction(Action::ReplaceSingle), mVisit(visit), mSingle(&node)
51 {}
52 
BaseResult(TIntermNode * node,VisitBits visit)53 TIntermRebuild::BaseResult::BaseResult(TIntermNode *node, VisitBits visit)
54     : mAction(node ? Action::ReplaceSingle : Action::Drop),
55       mVisit(node ? visit : VisitBits::Neither),
56       mSingle(node)
57 {}
58 
BaseResult(nullptr_t)59 TIntermRebuild::BaseResult::BaseResult(nullptr_t)
60     : mAction(Action::Drop), mVisit(VisitBits::Neither), mSingle(nullptr)
61 {}
62 
BaseResult(Fail)63 TIntermRebuild::BaseResult::BaseResult(Fail)
64     : mAction(Action::Fail), mVisit(VisitBits::Neither), mSingle(nullptr)
65 {}
66 
BaseResult(std::vector<TIntermNode * > && nodes)67 TIntermRebuild::BaseResult::BaseResult(std::vector<TIntermNode *> &&nodes)
68     : mAction(Action::ReplaceMulti),
69       mVisit(VisitBits::Neither),
70       mSingle(nullptr),
71       mMulti(std::move(nodes))
72 {}
73 
moveAssignImpl(BaseResult & other)74 void TIntermRebuild::BaseResult::moveAssignImpl(BaseResult &other)
75 {
76     mAction = other.mAction;
77     mVisit  = other.mVisit;
78     mSingle = other.mSingle;
79     mMulti  = std::move(other.mMulti);
80 }
81 
Multi(std::vector<TIntermNode * > && nodes)82 TIntermRebuild::BaseResult TIntermRebuild::BaseResult::Multi(std::vector<TIntermNode *> &&nodes)
83 {
84     auto it = std::remove(nodes.begin(), nodes.end(), nullptr);
85     nodes.erase(it, nodes.end());
86     return std::move(nodes);
87 }
88 
isFail() const89 bool TIntermRebuild::BaseResult::isFail() const
90 {
91     return mAction == Action::Fail;
92 }
93 
isDrop() const94 bool TIntermRebuild::BaseResult::isDrop() const
95 {
96     return mAction == Action::Drop;
97 }
98 
single() const99 TIntermNode *TIntermRebuild::BaseResult::single() const
100 {
101     return mSingle;
102 }
103 
multi() const104 const std::vector<TIntermNode *> *TIntermRebuild::BaseResult::multi() const
105 {
106     if (mAction == Action::ReplaceMulti)
107     {
108         return &mMulti;
109     }
110     return nullptr;
111 }
112 
113 ////////////////////////////////////////////////////////////////////////////////
114 
115 using PreResult = TIntermRebuild::PreResult;
116 
PreResult(TIntermNode & node,VisitBits visit)117 PreResult::PreResult(TIntermNode &node, VisitBits visit) : BaseResult(node, visit) {}
PreResult(TIntermNode * node,VisitBits visit)118 PreResult::PreResult(TIntermNode *node, VisitBits visit) : BaseResult(node, visit) {}
PreResult(nullptr_t)119 PreResult::PreResult(nullptr_t) : BaseResult(nullptr) {}
PreResult(Fail)120 PreResult::PreResult(Fail) : BaseResult(Fail()) {}
121 
PreResult(BaseResult && other)122 PreResult::PreResult(BaseResult &&other) : BaseResult(other) {}
PreResult(PreResult && other)123 PreResult::PreResult(PreResult &&other) : BaseResult(other) {}
124 
operator =(PreResult && other)125 void PreResult::operator=(PreResult &&other)
126 {
127     moveAssignImpl(other);
128 }
129 
130 ////////////////////////////////////////////////////////////////////////////////
131 
132 using PostResult = TIntermRebuild::PostResult;
133 
PostResult(TIntermNode & node)134 PostResult::PostResult(TIntermNode &node) : BaseResult(node, VisitBits::Neither) {}
PostResult(TIntermNode * node)135 PostResult::PostResult(TIntermNode *node) : BaseResult(node, VisitBits::Neither) {}
PostResult(nullptr_t)136 PostResult::PostResult(nullptr_t) : BaseResult(nullptr) {}
PostResult(Fail)137 PostResult::PostResult(Fail) : BaseResult(Fail()) {}
138 
PostResult(PostResult && other)139 PostResult::PostResult(PostResult &&other) : BaseResult(other) {}
PostResult(BaseResult && other)140 PostResult::PostResult(BaseResult &&other) : BaseResult(other) {}
141 
operator =(PostResult && other)142 void PostResult::operator=(PostResult &&other)
143 {
144     moveAssignImpl(other);
145 }
146 
147 ////////////////////////////////////////////////////////////////////////////////
148 
TIntermRebuild(TCompiler & compiler,bool preVisit,bool postVisit)149 TIntermRebuild::TIntermRebuild(TCompiler &compiler, bool preVisit, bool postVisit)
150     : mCompiler(compiler),
151       mSymbolTable(compiler.getSymbolTable()),
152       mPreVisit(preVisit),
153       mPostVisit(postVisit)
154 {
155     ASSERT(preVisit || postVisit);
156 }
157 
~TIntermRebuild()158 TIntermRebuild::~TIntermRebuild()
159 {
160     ASSERT(!mNodeStack.value);
161     ASSERT(!mNodeStack.tail);
162 }
163 
getParentFunction() const164 const TFunction *TIntermRebuild::getParentFunction() const
165 {
166     return mParentFunc;
167 }
168 
getParentNode(size_t offset) const169 TIntermNode *TIntermRebuild::getParentNode(size_t offset) const
170 {
171     ASSERT(mNodeStack.tail);
172     auto parent = *mNodeStack.tail;
173     while (offset > 0)
174     {
175         --offset;
176         ASSERT(parent.tail);
177         parent = *parent.tail;
178     }
179     return parent.value;
180 }
181 
rebuildRoot(TIntermBlock & root)182 bool TIntermRebuild::rebuildRoot(TIntermBlock &root)
183 {
184     if (!rebuildInPlace(root))
185     {
186         return false;
187     }
188     return mCompiler.validateAST(&root);
189 }
190 
rebuildInPlace(TIntermAggregate & node)191 bool TIntermRebuild::rebuildInPlace(TIntermAggregate &node)
192 {
193     return rebuildInPlaceImpl(node);
194 }
195 
rebuildInPlace(TIntermBlock & node)196 bool TIntermRebuild::rebuildInPlace(TIntermBlock &node)
197 {
198     return rebuildInPlaceImpl(node);
199 }
200 
rebuildInPlace(TIntermDeclaration & node)201 bool TIntermRebuild::rebuildInPlace(TIntermDeclaration &node)
202 {
203     return rebuildInPlaceImpl(node);
204 }
205 
206 template <typename Node>
rebuildInPlaceImpl(Node & node)207 bool TIntermRebuild::rebuildInPlaceImpl(Node &node)
208 {
209     auto *newNode = traverseAnyAs<Node>(node);
210     if (!newNode)
211     {
212         return false;
213     }
214 
215     if (newNode != &node)
216     {
217         *node.getSequence() = std::move(*newNode->getSequence());
218     }
219 
220     return true;
221 }
222 
rebuild(TIntermNode & node)223 PostResult TIntermRebuild::rebuild(TIntermNode &node)
224 {
225     return traverseAny(node);
226 }
227 
228 ////////////////////////////////////////////////////////////////////////////////
229 
230 template <typename Node>
traverseAnyAs(TIntermNode & node)231 Node *TIntermRebuild::traverseAnyAs(TIntermNode &node)
232 {
233     PostResult result(traverseAny(node));
234     if (result.mAction == Action::Fail || !result.mSingle)
235     {
236         return nullptr;
237     }
238     return asNode<Node>(result.mSingle);
239 }
240 
241 template <typename Node>
traverseAnyAs(TIntermNode & node,Node * & out)242 bool TIntermRebuild::traverseAnyAs(TIntermNode &node, Node *&out)
243 {
244     PostResult result(traverseAny(node));
245     if (result.mAction == Action::Fail || result.mAction == Action::ReplaceMulti)
246     {
247         return false;
248     }
249     if (!result.mSingle)
250     {
251         return true;
252     }
253     out = asNode<Node>(result.mSingle);
254     return out;
255 }
256 
traverseAggregateBaseChildren(TIntermAggregateBase & node)257 bool TIntermRebuild::traverseAggregateBaseChildren(TIntermAggregateBase &node)
258 {
259     auto *const children = node.getSequence();
260     ASSERT(children);
261     TIntermSequence newChildren;
262 
263     for (TIntermNode *child : *children)
264     {
265         ASSERT(child);
266         PostResult result(traverseAny(*child));
267 
268         switch (result.mAction)
269         {
270             case Action::ReplaceSingle:
271                 newChildren.push_back(result.mSingle);
272                 break;
273 
274             case Action::ReplaceMulti:
275                 for (TIntermNode *newNode : result.mMulti)
276                 {
277                     if (newNode)
278                     {
279                         newChildren.push_back(newNode);
280                     }
281                 }
282                 break;
283 
284             case Action::Drop:
285                 break;
286 
287             case Action::Fail:
288                 return false;
289         }
290     }
291 
292     *children = std::move(newChildren);
293 
294     return true;
295 }
296 
297 ////////////////////////////////////////////////////////////////////////////////
298 
299 struct TIntermRebuild::NodeStackGuard
300 {
301     ConsList<TIntermNode *> oldNodeStack;
302     ConsList<TIntermNode *> &nodeStack;
NodeStackGuardsh::TIntermRebuild::NodeStackGuard303     NodeStackGuard(ConsList<TIntermNode *> &nodeStack)
304         : oldNodeStack(nodeStack), nodeStack(nodeStack)
305     {}
~NodeStackGuardsh::TIntermRebuild::NodeStackGuard306     ~NodeStackGuard() { nodeStack = oldNodeStack; }
307 };
308 
traverseAny(TIntermNode & originalNode)309 PostResult TIntermRebuild::traverseAny(TIntermNode &originalNode)
310 {
311     PreResult preResult = traversePre(originalNode);
312     if (!preResult.mSingle)
313     {
314         ASSERT(preResult.mVisit == VisitBits::Neither);
315         return std::move(preResult);
316     }
317 
318     TIntermNode *currNode       = preResult.mSingle;
319     const VisitBits visit       = preResult.mVisit;
320     const NodeType currNodeType = getNodeType(*currNode);
321 
322     currNode = traverseChildren(currNodeType, originalNode, *currNode, visit);
323     if (!currNode)
324     {
325         return Fail();
326     }
327 
328     return traversePost(currNodeType, originalNode, *currNode, visit);
329 }
330 
traversePre(TIntermNode & originalNode)331 PreResult TIntermRebuild::traversePre(TIntermNode &originalNode)
332 {
333     if (!mPreVisit)
334     {
335         return {originalNode, VisitBits::Both};
336     }
337 
338     NodeStackGuard guard(mNodeStack);
339     mNodeStack = {&originalNode, &guard.oldNodeStack};
340 
341     const NodeType originalNodeType = getNodeType(originalNode);
342 
343     switch (originalNodeType)
344     {
345         case NodeType::Unknown:
346             ASSERT(false);
347             return Fail();
348         case NodeType::Symbol:
349             return visitSymbolPre(*originalNode.getAsSymbolNode());
350         case NodeType::ConstantUnion:
351             return visitConstantUnionPre(*originalNode.getAsConstantUnion());
352         case NodeType::FunctionPrototype:
353             return visitFunctionPrototypePre(*originalNode.getAsFunctionPrototypeNode());
354         case NodeType::PreprocessorDirective:
355             return visitPreprocessorDirectivePre(*originalNode.getAsPreprocessorDirective());
356         case NodeType::Unary:
357             return visitUnaryPre(*originalNode.getAsUnaryNode());
358         case NodeType::Binary:
359             return visitBinaryPre(*originalNode.getAsBinaryNode());
360         case NodeType::Ternary:
361             return visitTernaryPre(*originalNode.getAsTernaryNode());
362         case NodeType::Swizzle:
363             return visitSwizzlePre(*originalNode.getAsSwizzleNode());
364         case NodeType::IfElse:
365             return visitIfElsePre(*originalNode.getAsIfElseNode());
366         case NodeType::Switch:
367             return visitSwitchPre(*originalNode.getAsSwitchNode());
368         case NodeType::Case:
369             return visitCasePre(*originalNode.getAsCaseNode());
370         case NodeType::FunctionDefinition:
371             return visitFunctionDefinitionPre(*originalNode.getAsFunctionDefinition());
372         case NodeType::Aggregate:
373             return visitAggregatePre(*originalNode.getAsAggregate());
374         case NodeType::Block:
375             return visitBlockPre(*originalNode.getAsBlock());
376         case NodeType::GlobalQualifierDeclaration:
377             return visitGlobalQualifierDeclarationPre(
378                 *originalNode.getAsGlobalQualifierDeclarationNode());
379         case NodeType::Declaration:
380             return visitDeclarationPre(*originalNode.getAsDeclarationNode());
381         case NodeType::Loop:
382             return visitLoopPre(*originalNode.getAsLoopNode());
383         case NodeType::Branch:
384             return visitBranchPre(*originalNode.getAsBranchNode());
385     }
386 }
387 
traverseChildren(NodeType currNodeType,const TIntermNode & originalNode,TIntermNode & currNode,VisitBits visit)388 TIntermNode *TIntermRebuild::traverseChildren(NodeType currNodeType,
389                                               const TIntermNode &originalNode,
390                                               TIntermNode &currNode,
391                                               VisitBits visit)
392 {
393     if (!AnyBits(visit, VisitBits::Children))
394     {
395         return &currNode;
396     }
397 
398     if (AnyBits(visit, VisitBits::ChildrenRequiresSame) && &originalNode != &currNode)
399     {
400         return &currNode;
401     }
402 
403     NodeStackGuard guard(mNodeStack);
404     mNodeStack = {&currNode, &guard.oldNodeStack};
405 
406     switch (currNodeType)
407     {
408         case NodeType::Unknown:
409             ASSERT(false);
410             return nullptr;
411         case NodeType::Symbol:
412             return &currNode;
413         case NodeType::ConstantUnion:
414             return &currNode;
415         case NodeType::FunctionPrototype:
416             return &currNode;
417         case NodeType::PreprocessorDirective:
418             return &currNode;
419         case NodeType::Unary:
420             return traverseUnaryChildren(*currNode.getAsUnaryNode());
421         case NodeType::Binary:
422             return traverseBinaryChildren(*currNode.getAsBinaryNode());
423         case NodeType::Ternary:
424             return traverseTernaryChildren(*currNode.getAsTernaryNode());
425         case NodeType::Swizzle:
426             return traverseSwizzleChildren(*currNode.getAsSwizzleNode());
427         case NodeType::IfElse:
428             return traverseIfElseChildren(*currNode.getAsIfElseNode());
429         case NodeType::Switch:
430             return traverseSwitchChildren(*currNode.getAsSwitchNode());
431         case NodeType::Case:
432             return traverseCaseChildren(*currNode.getAsCaseNode());
433         case NodeType::FunctionDefinition:
434             return traverseFunctionDefinitionChildren(*currNode.getAsFunctionDefinition());
435         case NodeType::Aggregate:
436             return traverseAggregateChildren(*currNode.getAsAggregate());
437         case NodeType::Block:
438             return traverseBlockChildren(*currNode.getAsBlock());
439         case NodeType::GlobalQualifierDeclaration:
440             return traverseGlobalQualifierDeclarationChildren(
441                 *currNode.getAsGlobalQualifierDeclarationNode());
442         case NodeType::Declaration:
443             return traverseDeclarationChildren(*currNode.getAsDeclarationNode());
444         case NodeType::Loop:
445             return traverseLoopChildren(*currNode.getAsLoopNode());
446         case NodeType::Branch:
447             return traverseBranchChildren(*currNode.getAsBranchNode());
448     }
449 }
450 
traversePost(NodeType currNodeType,const TIntermNode & originalNode,TIntermNode & currNode,VisitBits visit)451 PostResult TIntermRebuild::traversePost(NodeType currNodeType,
452                                         const TIntermNode &originalNode,
453                                         TIntermNode &currNode,
454                                         VisitBits visit)
455 {
456     if (!mPostVisit)
457     {
458         return currNode;
459     }
460 
461     if (!AnyBits(visit, VisitBits::Post))
462     {
463         return currNode;
464     }
465 
466     if (AnyBits(visit, VisitBits::PostRequiresSame) && &originalNode != &currNode)
467     {
468         return currNode;
469     }
470 
471     NodeStackGuard guard(mNodeStack);
472     mNodeStack = {&currNode, &guard.oldNodeStack};
473 
474     switch (currNodeType)
475     {
476         case NodeType::Unknown:
477             ASSERT(false);
478             return Fail();
479         case NodeType::Symbol:
480             return visitSymbolPost(*currNode.getAsSymbolNode());
481         case NodeType::ConstantUnion:
482             return visitConstantUnionPost(*currNode.getAsConstantUnion());
483         case NodeType::FunctionPrototype:
484             return visitFunctionPrototypePost(*currNode.getAsFunctionPrototypeNode());
485         case NodeType::PreprocessorDirective:
486             return visitPreprocessorDirectivePost(*currNode.getAsPreprocessorDirective());
487         case NodeType::Unary:
488             return visitUnaryPost(*currNode.getAsUnaryNode());
489         case NodeType::Binary:
490             return visitBinaryPost(*currNode.getAsBinaryNode());
491         case NodeType::Ternary:
492             return visitTernaryPost(*currNode.getAsTernaryNode());
493         case NodeType::Swizzle:
494             return visitSwizzlePost(*currNode.getAsSwizzleNode());
495         case NodeType::IfElse:
496             return visitIfElsePost(*currNode.getAsIfElseNode());
497         case NodeType::Switch:
498             return visitSwitchPost(*currNode.getAsSwitchNode());
499         case NodeType::Case:
500             return visitCasePost(*currNode.getAsCaseNode());
501         case NodeType::FunctionDefinition:
502             return visitFunctionDefinitionPost(*currNode.getAsFunctionDefinition());
503         case NodeType::Aggregate:
504             return visitAggregatePost(*currNode.getAsAggregate());
505         case NodeType::Block:
506             return visitBlockPost(*currNode.getAsBlock());
507         case NodeType::GlobalQualifierDeclaration:
508             return visitGlobalQualifierDeclarationPost(
509                 *currNode.getAsGlobalQualifierDeclarationNode());
510         case NodeType::Declaration:
511             return visitDeclarationPost(*currNode.getAsDeclarationNode());
512         case NodeType::Loop:
513             return visitLoopPost(*currNode.getAsLoopNode());
514         case NodeType::Branch:
515             return visitBranchPost(*currNode.getAsBranchNode());
516     }
517 }
518 
519 ////////////////////////////////////////////////////////////////////////////////
520 
traverseAggregateChildren(TIntermAggregate & node)521 TIntermNode *TIntermRebuild::traverseAggregateChildren(TIntermAggregate &node)
522 {
523     if (traverseAggregateBaseChildren(node))
524     {
525         return &node;
526     }
527     return nullptr;
528 }
529 
traverseBlockChildren(TIntermBlock & node)530 TIntermNode *TIntermRebuild::traverseBlockChildren(TIntermBlock &node)
531 {
532     if (traverseAggregateBaseChildren(node))
533     {
534         return &node;
535     }
536     return nullptr;
537 }
538 
traverseDeclarationChildren(TIntermDeclaration & node)539 TIntermNode *TIntermRebuild::traverseDeclarationChildren(TIntermDeclaration &node)
540 {
541     if (traverseAggregateBaseChildren(node))
542     {
543         return &node;
544     }
545     return nullptr;
546 }
547 
traverseSwizzleChildren(TIntermSwizzle & node)548 TIntermNode *TIntermRebuild::traverseSwizzleChildren(TIntermSwizzle &node)
549 {
550     auto *const operand = node.getOperand();
551     ASSERT(operand);
552 
553     auto *newOperand = traverseAnyAs<TIntermTyped>(*operand);
554     GUARD(newOperand);
555 
556     if (newOperand != operand)
557     {
558         return new TIntermSwizzle(newOperand, node.getSwizzleOffsets());
559     }
560 
561     return &node;
562 }
563 
traverseBinaryChildren(TIntermBinary & node)564 TIntermNode *TIntermRebuild::traverseBinaryChildren(TIntermBinary &node)
565 {
566     auto *const left = node.getLeft();
567     ASSERT(left);
568     auto *const right = node.getRight();
569     ASSERT(right);
570 
571     auto *const newLeft = traverseAnyAs<TIntermTyped>(*left);
572     GUARD(newLeft);
573     auto *const newRight = traverseAnyAs<TIntermTyped>(*right);
574     GUARD(newRight);
575 
576     if (newLeft != left || newRight != right)
577     {
578         TOperator op = node.getOp();
579         switch (op)
580         {
581             case TOperator::EOpIndexDirectStruct:
582             {
583                 if (newLeft->getType().getInterfaceBlock())
584                 {
585                     op = TOperator::EOpIndexDirectInterfaceBlock;
586                 }
587             }
588             break;
589 
590             case TOperator::EOpIndexDirectInterfaceBlock:
591             {
592                 if (newLeft->getType().getStruct())
593                 {
594                     op = TOperator::EOpIndexDirectStruct;
595                 }
596             }
597             break;
598 
599             case TOperator::EOpComma:
600                 return TIntermBinary::CreateComma(newLeft, newRight, mCompiler.getShaderVersion());
601 
602             default:
603                 break;
604         }
605 
606         return new TIntermBinary(op, newLeft, newRight);
607     }
608 
609     return &node;
610 }
611 
traverseUnaryChildren(TIntermUnary & node)612 TIntermNode *TIntermRebuild::traverseUnaryChildren(TIntermUnary &node)
613 {
614     auto *const operand = node.getOperand();
615     ASSERT(operand);
616 
617     auto *const newOperand = traverseAnyAs<TIntermTyped>(*operand);
618     GUARD(newOperand);
619 
620     if (newOperand != operand)
621     {
622         return new TIntermUnary(node.getOp(), newOperand, node.getFunction());
623     }
624 
625     return &node;
626 }
627 
traverseTernaryChildren(TIntermTernary & node)628 TIntermNode *TIntermRebuild::traverseTernaryChildren(TIntermTernary &node)
629 {
630     auto *const cond = node.getCondition();
631     ASSERT(cond);
632     auto *const true_ = node.getTrueExpression();
633     ASSERT(true_);
634     auto *const false_ = node.getFalseExpression();
635     ASSERT(false_);
636 
637     auto *const newCond = traverseAnyAs<TIntermTyped>(*cond);
638     GUARD(newCond);
639     auto *const newTrue = traverseAnyAs<TIntermTyped>(*true_);
640     GUARD(newTrue);
641     auto *const newFalse = traverseAnyAs<TIntermTyped>(*false_);
642     GUARD(newFalse);
643 
644     if (newCond != cond || newTrue != true_ || newFalse != false_)
645     {
646         return new TIntermTernary(newCond, newTrue, newFalse);
647     }
648 
649     return &node;
650 }
651 
traverseIfElseChildren(TIntermIfElse & node)652 TIntermNode *TIntermRebuild::traverseIfElseChildren(TIntermIfElse &node)
653 {
654     auto *const cond = node.getCondition();
655     ASSERT(cond);
656     auto *const true_  = node.getTrueBlock();
657     auto *const false_ = node.getFalseBlock();
658 
659     auto *const newCond = traverseAnyAs<TIntermTyped>(*cond);
660     GUARD(newCond);
661     TIntermBlock *newTrue = nullptr;
662     if (true_)
663     {
664         GUARD(traverseAnyAs(*true_, newTrue));
665     }
666     TIntermBlock *newFalse = nullptr;
667     if (false_)
668     {
669         GUARD(traverseAnyAs(*false_, newFalse));
670     }
671 
672     if (newCond != cond || newTrue != true_ || newFalse != false_)
673     {
674         return new TIntermIfElse(newCond, newTrue, newFalse);
675     }
676 
677     return &node;
678 }
679 
traverseSwitchChildren(TIntermSwitch & node)680 TIntermNode *TIntermRebuild::traverseSwitchChildren(TIntermSwitch &node)
681 {
682     auto *const init = node.getInit();
683     ASSERT(init);
684     auto *const stmts = node.getStatementList();
685     ASSERT(stmts);
686 
687     auto *const newInit = traverseAnyAs<TIntermTyped>(*init);
688     GUARD(newInit);
689     auto *const newStmts = traverseAnyAs<TIntermBlock>(*stmts);
690     GUARD(newStmts);
691 
692     if (newInit != init || newStmts != stmts)
693     {
694         return new TIntermSwitch(newInit, newStmts);
695     }
696 
697     return &node;
698 }
699 
traverseCaseChildren(TIntermCase & node)700 TIntermNode *TIntermRebuild::traverseCaseChildren(TIntermCase &node)
701 {
702     auto *const cond = node.getCondition();
703 
704     TIntermTyped *newCond = nullptr;
705     if (cond)
706     {
707         GUARD(traverseAnyAs(*cond, newCond));
708     }
709 
710     if (newCond != cond)
711     {
712         return new TIntermCase(newCond);
713     }
714 
715     return &node;
716 }
717 
traverseFunctionDefinitionChildren(TIntermFunctionDefinition & node)718 TIntermNode *TIntermRebuild::traverseFunctionDefinitionChildren(TIntermFunctionDefinition &node)
719 {
720     GUARD(!mParentFunc);  // Function definitions cannot be nested.
721     mParentFunc = node.getFunction();
722     struct OnExit
723     {
724         const TFunction *&parentFunc;
725         OnExit(const TFunction *&parentFunc) : parentFunc(parentFunc) {}
726         ~OnExit() { parentFunc = nullptr; }
727     } onExit(mParentFunc);
728 
729     auto *const proto = node.getFunctionPrototype();
730     ASSERT(proto);
731     auto *const body = node.getBody();
732     ASSERT(body);
733 
734     auto *const newProto = traverseAnyAs<TIntermFunctionPrototype>(*proto);
735     GUARD(newProto);
736     auto *const newBody = traverseAnyAs<TIntermBlock>(*body);
737     GUARD(newBody);
738 
739     if (newProto != proto || newBody != body)
740     {
741         return new TIntermFunctionDefinition(newProto, newBody);
742     }
743 
744     return &node;
745 }
746 
traverseGlobalQualifierDeclarationChildren(TIntermGlobalQualifierDeclaration & node)747 TIntermNode *TIntermRebuild::traverseGlobalQualifierDeclarationChildren(
748     TIntermGlobalQualifierDeclaration &node)
749 {
750     auto *const symbol = node.getSymbol();
751     ASSERT(symbol);
752 
753     auto *const newSymbol = traverseAnyAs<TIntermSymbol>(*symbol);
754     GUARD(newSymbol);
755 
756     if (newSymbol != symbol)
757     {
758         return new TIntermGlobalQualifierDeclaration(newSymbol, node.isPrecise(), node.getLine());
759     }
760 
761     return &node;
762 }
763 
traverseLoopChildren(TIntermLoop & node)764 TIntermNode *TIntermRebuild::traverseLoopChildren(TIntermLoop &node)
765 {
766     const TLoopType loopType = node.getType();
767 
768     auto *const init = node.getInit();
769     auto *const cond = node.getCondition();
770     auto *const expr = node.getExpression();
771     auto *const body = node.getBody();
772     ASSERT(body);
773 
774 #if defined(ANGLE_ENABLE_ASSERTS)
775     switch (loopType)
776     {
777         case TLoopType::ELoopFor:
778             break;
779         case TLoopType::ELoopWhile:
780         case TLoopType::ELoopDoWhile:
781             ASSERT(cond);
782             ASSERT(!init && !expr);
783             break;
784     }
785 #endif
786 
787     auto *const newBody = traverseAnyAs<TIntermBlock>(*body);
788     GUARD(newBody);
789     TIntermNode *newInit = nullptr;
790     if (init)
791     {
792         GUARD(traverseAnyAs(*init, newInit));
793     }
794     TIntermTyped *newCond = nullptr;
795     if (cond)
796     {
797         GUARD(traverseAnyAs(*cond, newCond));
798     }
799     TIntermTyped *newExpr = nullptr;
800     if (expr)
801     {
802         GUARD(traverseAnyAs(*expr, newExpr));
803     }
804 
805     if (newInit != init || newCond != cond || newExpr != expr || newBody != body)
806     {
807         switch (loopType)
808         {
809             case TLoopType::ELoopFor:
810                 GUARD(newBody);
811                 break;
812             case TLoopType::ELoopWhile:
813             case TLoopType::ELoopDoWhile:
814                 GUARD(newCond && newBody);
815                 GUARD(!newInit && !newExpr);
816                 break;
817         }
818         return new TIntermLoop(loopType, newInit, newCond, newExpr, newBody);
819     }
820 
821     return &node;
822 }
823 
traverseBranchChildren(TIntermBranch & node)824 TIntermNode *TIntermRebuild::traverseBranchChildren(TIntermBranch &node)
825 {
826     auto *const expr = node.getExpression();
827 
828     TIntermTyped *newExpr = nullptr;
829     if (expr)
830     {
831         GUARD(traverseAnyAs<TIntermTyped>(*expr, newExpr));
832     }
833 
834     if (newExpr != expr)
835     {
836         return new TIntermBranch(node.getFlowOp(), newExpr);
837     }
838 
839     return &node;
840 }
841 
842 ////////////////////////////////////////////////////////////////////////////////
843 
visitSymbolPre(TIntermSymbol & node)844 PreResult TIntermRebuild::visitSymbolPre(TIntermSymbol &node)
845 {
846     return {node, VisitBits::Both};
847 }
848 
visitConstantUnionPre(TIntermConstantUnion & node)849 PreResult TIntermRebuild::visitConstantUnionPre(TIntermConstantUnion &node)
850 {
851     return {node, VisitBits::Both};
852 }
853 
visitFunctionPrototypePre(TIntermFunctionPrototype & node)854 PreResult TIntermRebuild::visitFunctionPrototypePre(TIntermFunctionPrototype &node)
855 {
856     return {node, VisitBits::Both};
857 }
858 
visitPreprocessorDirectivePre(TIntermPreprocessorDirective & node)859 PreResult TIntermRebuild::visitPreprocessorDirectivePre(TIntermPreprocessorDirective &node)
860 {
861     return {node, VisitBits::Both};
862 }
863 
visitUnaryPre(TIntermUnary & node)864 PreResult TIntermRebuild::visitUnaryPre(TIntermUnary &node)
865 {
866     return {node, VisitBits::Both};
867 }
868 
visitBinaryPre(TIntermBinary & node)869 PreResult TIntermRebuild::visitBinaryPre(TIntermBinary &node)
870 {
871     return {node, VisitBits::Both};
872 }
873 
visitTernaryPre(TIntermTernary & node)874 PreResult TIntermRebuild::visitTernaryPre(TIntermTernary &node)
875 {
876     return {node, VisitBits::Both};
877 }
878 
visitSwizzlePre(TIntermSwizzle & node)879 PreResult TIntermRebuild::visitSwizzlePre(TIntermSwizzle &node)
880 {
881     return {node, VisitBits::Both};
882 }
883 
visitIfElsePre(TIntermIfElse & node)884 PreResult TIntermRebuild::visitIfElsePre(TIntermIfElse &node)
885 {
886     return {node, VisitBits::Both};
887 }
888 
visitSwitchPre(TIntermSwitch & node)889 PreResult TIntermRebuild::visitSwitchPre(TIntermSwitch &node)
890 {
891     return {node, VisitBits::Both};
892 }
893 
visitCasePre(TIntermCase & node)894 PreResult TIntermRebuild::visitCasePre(TIntermCase &node)
895 {
896     return {node, VisitBits::Both};
897 }
898 
visitLoopPre(TIntermLoop & node)899 PreResult TIntermRebuild::visitLoopPre(TIntermLoop &node)
900 {
901     return {node, VisitBits::Both};
902 }
903 
visitBranchPre(TIntermBranch & node)904 PreResult TIntermRebuild::visitBranchPre(TIntermBranch &node)
905 {
906     return {node, VisitBits::Both};
907 }
908 
visitDeclarationPre(TIntermDeclaration & node)909 PreResult TIntermRebuild::visitDeclarationPre(TIntermDeclaration &node)
910 {
911     return {node, VisitBits::Both};
912 }
913 
visitBlockPre(TIntermBlock & node)914 PreResult TIntermRebuild::visitBlockPre(TIntermBlock &node)
915 {
916     return {node, VisitBits::Both};
917 }
918 
visitAggregatePre(TIntermAggregate & node)919 PreResult TIntermRebuild::visitAggregatePre(TIntermAggregate &node)
920 {
921     return {node, VisitBits::Both};
922 }
923 
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)924 PreResult TIntermRebuild::visitFunctionDefinitionPre(TIntermFunctionDefinition &node)
925 {
926     return {node, VisitBits::Both};
927 }
928 
visitGlobalQualifierDeclarationPre(TIntermGlobalQualifierDeclaration & node)929 PreResult TIntermRebuild::visitGlobalQualifierDeclarationPre(
930     TIntermGlobalQualifierDeclaration &node)
931 {
932     return {node, VisitBits::Both};
933 }
934 
935 ////////////////////////////////////////////////////////////////////////////////
936 
visitSymbolPost(TIntermSymbol & node)937 PostResult TIntermRebuild::visitSymbolPost(TIntermSymbol &node)
938 {
939     return node;
940 }
941 
visitConstantUnionPost(TIntermConstantUnion & node)942 PostResult TIntermRebuild::visitConstantUnionPost(TIntermConstantUnion &node)
943 {
944     return node;
945 }
946 
visitFunctionPrototypePost(TIntermFunctionPrototype & node)947 PostResult TIntermRebuild::visitFunctionPrototypePost(TIntermFunctionPrototype &node)
948 {
949     return node;
950 }
951 
visitPreprocessorDirectivePost(TIntermPreprocessorDirective & node)952 PostResult TIntermRebuild::visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node)
953 {
954     return node;
955 }
956 
visitUnaryPost(TIntermUnary & node)957 PostResult TIntermRebuild::visitUnaryPost(TIntermUnary &node)
958 {
959     return node;
960 }
961 
visitBinaryPost(TIntermBinary & node)962 PostResult TIntermRebuild::visitBinaryPost(TIntermBinary &node)
963 {
964     return node;
965 }
966 
visitTernaryPost(TIntermTernary & node)967 PostResult TIntermRebuild::visitTernaryPost(TIntermTernary &node)
968 {
969     return node;
970 }
971 
visitSwizzlePost(TIntermSwizzle & node)972 PostResult TIntermRebuild::visitSwizzlePost(TIntermSwizzle &node)
973 {
974     return node;
975 }
976 
visitIfElsePost(TIntermIfElse & node)977 PostResult TIntermRebuild::visitIfElsePost(TIntermIfElse &node)
978 {
979     return node;
980 }
981 
visitSwitchPost(TIntermSwitch & node)982 PostResult TIntermRebuild::visitSwitchPost(TIntermSwitch &node)
983 {
984     return node;
985 }
986 
visitCasePost(TIntermCase & node)987 PostResult TIntermRebuild::visitCasePost(TIntermCase &node)
988 {
989     return node;
990 }
991 
visitLoopPost(TIntermLoop & node)992 PostResult TIntermRebuild::visitLoopPost(TIntermLoop &node)
993 {
994     return node;
995 }
996 
visitBranchPost(TIntermBranch & node)997 PostResult TIntermRebuild::visitBranchPost(TIntermBranch &node)
998 {
999     return node;
1000 }
1001 
visitDeclarationPost(TIntermDeclaration & node)1002 PostResult TIntermRebuild::visitDeclarationPost(TIntermDeclaration &node)
1003 {
1004     return node;
1005 }
1006 
visitBlockPost(TIntermBlock & node)1007 PostResult TIntermRebuild::visitBlockPost(TIntermBlock &node)
1008 {
1009     return node;
1010 }
1011 
visitAggregatePost(TIntermAggregate & node)1012 PostResult TIntermRebuild::visitAggregatePost(TIntermAggregate &node)
1013 {
1014     return node;
1015 }
1016 
visitFunctionDefinitionPost(TIntermFunctionDefinition & node)1017 PostResult TIntermRebuild::visitFunctionDefinitionPost(TIntermFunctionDefinition &node)
1018 {
1019     return node;
1020 }
1021 
visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration & node)1022 PostResult TIntermRebuild::visitGlobalQualifierDeclarationPost(
1023     TIntermGlobalQualifierDeclaration &node)
1024 {
1025     return node;
1026 }
1027 
1028 }  // namespace sh
1029