1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #ifndef COMPILER_TRANSLATOR_TRANSLATORMETALDIRECT_INTERMREBUILD_H_
8 #define COMPILER_TRANSLATOR_TRANSLATORMETALDIRECT_INTERMREBUILD_H_
9 
#include <cstddef>
#include <utility>
#include <vector>

#include "compiler/translator/TranslatorMetalDirect/NodeType.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
12 
13 namespace sh
14 {
15 
16 // Walks the tree to rebuild nodes.
17 // This class is intended to be derived with overridden visitXXX functions.
18 //
// Each visitXXX function that does not have a Visit parameter is called exactly once,
// regardless of the (preVisit) and (postVisit) values.
//
// Each visitXXX function that has a Visit parameter behaves as follows:
23 //    * If (preVisit):
24 //      - The node is visited before children are traversed.
25 //      - The returned value is used to replace the visited node. The returned value may be the same
26 //        as the original node.
//      - If multiple nodes are returned, children and post visits of the returned nodes are not
//        performed, even if it is a singleton collection.
29 //    * If (childVisit)
30 //      - If any new children are returned, the node is automatically rebuilt with the new children
31 //        before post visit.
32 //      - Depending on the type of the node, null children may be discarded.
33 //      - Ill-typed children cause rebuild errors. Ill-typed means the node to automatically rebuild
34 //        cannot accept a child of a certain type as input to its constructor.
//      - Only instances of TIntermAggregateBase can accept Multi results for any of their
//        children. If supplied, the nodes are spliced in as children at the position of the
//        original child.
37 //    * If (postVisit)
38 //      - The node is visited after any children are traversed.
//      - The post visit is performed only after such a rebuild (or lack thereof).
40 //
41 // Nodes in visit functions are allowed to be modified in place, including TIntermAggregateBase
42 // child sequences.
43 //
44 // The default implementations of all the visitXXX functions support full pre and post traversal
45 // without modifying the visited nodes.
46 //
47 class TIntermRebuild : angle::NonCopyable
48 {
49 
50     enum class Action
51     {
52         ReplaceSingle,
53         ReplaceMulti,
54         Drop,
55         Fail,
56     };
57 
58   public:
59     struct Fail
60     {};
61 
62     enum VisitBits : size_t
63     {
64         // No bits are set.
65         Empty = 0u,
66 
67         // Allow visit of returned node's children.
68         Children = 1u << 0u,
69 
70         // Allow post visit of returned node.
71         Post = 1u << 1u,
72 
73         // If (Children) bit, only visit if the returned node is the same as the original node.
74         ChildrenRequiresSame = 1u << 2u,
75 
76         // If (Post) bit, only visit if the returned node is the same as the original node.
77         PostRequiresSame = 1u << 3u,
78 
79         RequireSame  = ChildrenRequiresSame | PostRequiresSame,
80         Neither      = Empty,
81         Both         = Children | Post,
82         BothWhenSame = Both | RequireSame,
83     };
84 
85   private:
86     struct NodeStackGuard;
87 
88     template <typename T>
89     struct ConsList
90     {
91         T value;
92         ConsList<T> *tail;
93     };
94 
95     class BaseResult
96     {
97         BaseResult(const BaseResult &) = delete;
98         BaseResult &operator=(const BaseResult &) = delete;
99 
100       public:
101         BaseResult(BaseResult &&other) = default;
102         BaseResult(BaseResult &other);  // For subclass move constructor impls
103         BaseResult(TIntermNode &node, VisitBits visit);
104         BaseResult(TIntermNode *node, VisitBits visit);
105         BaseResult(nullptr_t);
106         BaseResult(Fail);
107         BaseResult(std::vector<TIntermNode *> &&nodes);
108 
109         void moveAssignImpl(BaseResult &other);  // For subclass move assign impls
110 
111         static BaseResult Multi(std::vector<TIntermNode *> &&nodes);
112 
113         template <typename Iter>
Multi(Iter nodesBegin,Iter nodesEnd)114         static BaseResult Multi(Iter nodesBegin, Iter nodesEnd)
115         {
116             std::vector<TIntermNode *> nodes;
117             for (Iter nodesCurr = nodesBegin; nodesCurr != nodesEnd; ++nodesCurr)
118             {
119                 nodes.push_back(*nodesCurr);
120             }
121             return std::move(nodes);
122         }
123 
124         bool isFail() const;
125         bool isDrop() const;
126         TIntermNode *single() const;
127         const std::vector<TIntermNode *> *multi() const;
128 
129       public:
130         Action mAction;
131         VisitBits mVisit;
132         TIntermNode *mSingle;
133         std::vector<TIntermNode *> mMulti;
134     };
135 
136   public:
137     class PreResult : private BaseResult
138     {
139         friend class TIntermRebuild;
140 
141       public:
142         PreResult(PreResult &&other);
143         PreResult(TIntermNode &node, VisitBits visit = VisitBits::BothWhenSame);
144         PreResult(TIntermNode *node, VisitBits visit = VisitBits::BothWhenSame);
145         PreResult(nullptr_t);  // Used to drop a node.
146         PreResult(Fail);       // Used to signal failure.
147 
148         void operator=(PreResult &&other);
149 
Multi(std::vector<TIntermNode * > && nodes)150         static PreResult Multi(std::vector<TIntermNode *> &&nodes)
151         {
152             return BaseResult::Multi(std::move(nodes));
153         }
154 
155         template <typename Iter>
Multi(Iter nodesBegin,Iter nodesEnd)156         static PreResult Multi(Iter nodesBegin, Iter nodesEnd)
157         {
158             return BaseResult::Multi(nodesBegin, nodesEnd);
159         }
160 
161         using BaseResult::isDrop;
162         using BaseResult::isFail;
163         using BaseResult::multi;
164         using BaseResult::single;
165 
166       private:
167         PreResult(BaseResult &&other);
168     };
169 
170     class PostResult : private BaseResult
171     {
172         friend class TIntermRebuild;
173 
174       public:
175         PostResult(PostResult &&other);
176         PostResult(TIntermNode &node);
177         PostResult(TIntermNode *node);
178         PostResult(nullptr_t);  // Used to drop a node
179         PostResult(Fail);       // Used to signal failure.
180 
181         void operator=(PostResult &&other);
182 
Multi(std::vector<TIntermNode * > && nodes)183         static PostResult Multi(std::vector<TIntermNode *> &&nodes)
184         {
185             return BaseResult::Multi(std::move(nodes));
186         }
187 
188         template <typename Iter>
Multi(Iter nodesBegin,Iter nodesEnd)189         static PostResult Multi(Iter nodesBegin, Iter nodesEnd)
190         {
191             return BaseResult::Multi(nodesBegin, nodesEnd);
192         }
193 
194         using BaseResult::isDrop;
195         using BaseResult::isFail;
196         using BaseResult::multi;
197         using BaseResult::single;
198 
199       private:
200         PostResult(BaseResult &&other);
201     };
202 
203   public:
204     TIntermRebuild(TCompiler &compiler, bool preVisit, bool postVisit);
205 
206     virtual ~TIntermRebuild();
207 
208     // Rebuilds the tree starting at the provided root. If a new node would be returned for the
209     // root, the root node's children become that of the new node instead. Returns false if failure
210     // occurred.
211     ANGLE_NO_DISCARD bool rebuildRoot(TIntermBlock &root);
212 
213   protected:
214     virtual PreResult visitSymbolPre(TIntermSymbol &node);
215     virtual PreResult visitConstantUnionPre(TIntermConstantUnion &node);
216     virtual PreResult visitFunctionPrototypePre(TIntermFunctionPrototype &node);
217     virtual PreResult visitPreprocessorDirectivePre(TIntermPreprocessorDirective &node);
218     virtual PreResult visitUnaryPre(TIntermUnary &node);
219     virtual PreResult visitBinaryPre(TIntermBinary &node);
220     virtual PreResult visitTernaryPre(TIntermTernary &node);
221     virtual PreResult visitSwizzlePre(TIntermSwizzle &node);
222     virtual PreResult visitIfElsePre(TIntermIfElse &node);
223     virtual PreResult visitSwitchPre(TIntermSwitch &node);
224     virtual PreResult visitCasePre(TIntermCase &node);
225     virtual PreResult visitLoopPre(TIntermLoop &node);
226     virtual PreResult visitBranchPre(TIntermBranch &node);
227     virtual PreResult visitDeclarationPre(TIntermDeclaration &node);
228     virtual PreResult visitBlockPre(TIntermBlock &node);
229     virtual PreResult visitAggregatePre(TIntermAggregate &node);
230     virtual PreResult visitFunctionDefinitionPre(TIntermFunctionDefinition &node);
231     virtual PreResult visitGlobalQualifierDeclarationPre(TIntermGlobalQualifierDeclaration &node);
232 
233     virtual PostResult visitSymbolPost(TIntermSymbol &node);
234     virtual PostResult visitConstantUnionPost(TIntermConstantUnion &node);
235     virtual PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &node);
236     virtual PostResult visitPreprocessorDirectivePost(TIntermPreprocessorDirective &node);
237     virtual PostResult visitUnaryPost(TIntermUnary &node);
238     virtual PostResult visitBinaryPost(TIntermBinary &node);
239     virtual PostResult visitTernaryPost(TIntermTernary &node);
240     virtual PostResult visitSwizzlePost(TIntermSwizzle &node);
241     virtual PostResult visitIfElsePost(TIntermIfElse &node);
242     virtual PostResult visitSwitchPost(TIntermSwitch &node);
243     virtual PostResult visitCasePost(TIntermCase &node);
244     virtual PostResult visitLoopPost(TIntermLoop &node);
245     virtual PostResult visitBranchPost(TIntermBranch &node);
246     virtual PostResult visitDeclarationPost(TIntermDeclaration &node);
247     virtual PostResult visitBlockPost(TIntermBlock &node);
248     virtual PostResult visitAggregatePost(TIntermAggregate &node);
249     virtual PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &node);
250     virtual PostResult visitGlobalQualifierDeclarationPost(TIntermGlobalQualifierDeclaration &node);
251 
252     // Can be used to rebuild a specific node during a traversal. Useful for fine control of
253     // rebuilding a node's children.
254     ANGLE_NO_DISCARD PostResult rebuild(TIntermNode &node);
255 
256     // Rebuilds the provided node in place. If a new node would be returned, the old node's children
257     // become that of the new node instead. Returns false if failure occurred.
258     ANGLE_NO_DISCARD bool rebuildInPlace(TIntermAggregate &node);
259 
260     // Rebuilds the provided node in place. If a new node would be returned, the old node's children
261     // become that of the new node instead. Returns false if failure occurred.
262     ANGLE_NO_DISCARD bool rebuildInPlace(TIntermBlock &node);
263 
264     // Rebuilds the provided node in place. If a new node would be returned, the old node's children
265     // become that of the new node instead. Returns false if failure occurred.
266     ANGLE_NO_DISCARD bool rebuildInPlace(TIntermDeclaration &node);
267 
268     // If currently at or below a function declaration body, this returns the function that encloses
269     // the currently visited node. (This returns null if at a function declaration node.)
270     const TFunction *getParentFunction() const;
271 
272     TIntermNode *getParentNode(size_t offset = 0) const;
273 
274   private:
275     template <typename Node>
276     ANGLE_NO_DISCARD bool rebuildInPlaceImpl(Node &node);
277 
278     PostResult traverseAny(TIntermNode &node);
279 
280     template <typename Node>
281     Node *traverseAnyAs(TIntermNode &node);
282 
283     template <typename Node>
284     bool traverseAnyAs(TIntermNode &node, Node *&out);
285 
286     PreResult traversePre(TIntermNode &originalNode);
287     TIntermNode *traverseChildren(NodeType currNodeType,
288                                   const TIntermNode &originalNode,
289                                   TIntermNode &currNode,
290                                   VisitBits visit);
291     PostResult traversePost(NodeType nodeType,
292                             const TIntermNode &originalNode,
293                             TIntermNode &currNode,
294                             VisitBits visit);
295 
296     bool traverseAggregateBaseChildren(TIntermAggregateBase &node);
297 
298     TIntermNode *traverseUnaryChildren(TIntermUnary &node);
299     TIntermNode *traverseBinaryChildren(TIntermBinary &node);
300     TIntermNode *traverseTernaryChildren(TIntermTernary &node);
301     TIntermNode *traverseSwizzleChildren(TIntermSwizzle &node);
302     TIntermNode *traverseIfElseChildren(TIntermIfElse &node);
303     TIntermNode *traverseSwitchChildren(TIntermSwitch &node);
304     TIntermNode *traverseCaseChildren(TIntermCase &node);
305     TIntermNode *traverseLoopChildren(TIntermLoop &node);
306     TIntermNode *traverseBranchChildren(TIntermBranch &node);
307     TIntermNode *traverseDeclarationChildren(TIntermDeclaration &node);
308     TIntermNode *traverseBlockChildren(TIntermBlock &node);
309     TIntermNode *traverseAggregateChildren(TIntermAggregate &node);
310     TIntermNode *traverseFunctionDefinitionChildren(TIntermFunctionDefinition &node);
311     TIntermNode *traverseGlobalQualifierDeclarationChildren(
312         TIntermGlobalQualifierDeclaration &node);
313 
314   protected:
315     TCompiler &mCompiler;
316     TSymbolTable &mSymbolTable;
317     const TFunction *mParentFunc = nullptr;
318     GetNodeType getNodeType;
319 
320   private:
321     ConsList<TIntermNode *> mNodeStack{nullptr, nullptr};
322     bool mPreVisit;
323     bool mPostVisit;
324 };
325 
326 }  // namespace sh
327 
328 #endif  // COMPILER_TRANSLATOR_TRANSLATORMETALDIRECT_INTERMREBUILD_H_
329