1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #ifndef SKSL_ASTNODE
9 #define SKSL_ASTNODE
10 
11 #include "include/private/SkSLModifiers.h"
12 #include "include/private/SkSLString.h"
13 #include "src/sksl/SkSLLexer.h"
14 #include "src/sksl/SkSLOperators.h"
15 
16 #include <algorithm>
17 #include <vector>
18 
19 namespace SkSL {
20 
21 /**
22  * Represents a node in the abstract syntax tree (AST). The AST is based directly on the parse tree;
23  * it is a parsed-but-not-yet-analyzed version of the program.
24  */
25 struct ASTNode {
26     class ID {
27     public:
InvalidASTNode28         static ID Invalid() {
29             return ID();
30         }
31 
32         bool operator==(const ID& other) {
33             return fValue == other.fValue;
34         }
35 
36         bool operator!=(const ID& other) {
37             return fValue != other.fValue;
38         }
39 
40         operator bool() const { return fValue >= 0; }
41 
42     private:
IDASTNode43         ID()
44             : fValue(-1) {}
45 
IDASTNode46         ID(int value)
47             : fValue(value) {}
48 
49         int fValue;
50 
51         friend struct ASTFile;
52         friend struct ASTNode;
53         friend class Parser;
54     };
55 
56     enum class Kind {
57         // data: operator, children: left, right
58         kBinary,
59         // children: statements
60         kBlock,
61         // data: value(bool)
62         kBool,
63         kBreak,
64         // children: target, arg1, arg2...
65         kCall,
66         kContinue,
67         kDiscard,
68         // children: statement, test
69         kDo,
70         // data: name(StringFragment), children: enumCases
71         kEnum,
72         // data: name(StringFragment), children: value?
73         kEnumCase,
74         // data: name(StringFragment)
75         kExtension,
76         // data: field(StringFragment), children: base
77         kField,
78         // children: declarations
79         kFile,
80         // data: value(float)
81         kFloat,
82         // children: init, test, next, statement
83         kFor,
84         // data: FunctionData, children: returnType, parameters, statement?
85         kFunction,
86         // data: name(StringFragment)
87         kIdentifier,
88         // children: base, index?
89         kIndex,
90         // data: isStatic(bool), children: test, ifTrue, ifFalse?
91         kIf,
92         // value(data): int
93         kInt,
94         // data: InterfaceBlockData, children: declaration1, declaration2, ..., size1, size2, ...
95         kInterfaceBlock,
96         // data: Modifiers
97         kModifiers,
98         kNull,
99         // data: ParameterData, children: type, arraySize1, arraySize2, ..., value?
100         kParameter,
101         // data: operator, children: operand
102         kPostfix,
103         // data: operator, children: operand
104         kPrefix,
105         // children: value
106         kReturn,
107         // data: field(StringFragment), children: base
108         kScope,
109         // ...
110         kSection,
111         // children: value, statement 1, statement 2...
112         kSwitchCase,
113         // children: value, case 1, case 2...
114         kSwitch,
115         // children: test, ifTrue, ifFalse
116         kTernary,
117         // data: name(StringFragment), children: sizes
118         kType,
119         // data: VarData, children: arraySize1, arraySize2, ..., value?
120         kVarDeclaration,
121         // children: modifiers, type, varDeclaration1, varDeclaration2, ...
122         kVarDeclarations,
123         // children: test, statement
124         kWhile,
125     };
126 
127     class iterator {
128     public:
129         iterator operator++() {
130             SkASSERT(fID);
131             fID = (**this).fNext;
132             return *this;
133         }
134 
135         iterator operator++(int) {
136             SkASSERT(fID);
137             iterator old = *this;
138             fID = (**this).fNext;
139             return old;
140         }
141 
142         iterator operator+=(int count) {
143             SkASSERT(count >= 0);
144             for (; count > 0; --count) {
145                 ++(*this);
146             }
147             return *this;
148         }
149 
150         iterator operator+(int count) {
151             iterator result(*this);
152             return result += count;
153         }
154 
155         bool operator==(const iterator& other) const {
156             return fID == other.fID;
157         }
158 
159         bool operator!=(const iterator& other) const {
160             return fID != other.fID;
161         }
162 
163         ASTNode& operator*() {
164             SkASSERT(fID);
165             return (*fNodes)[fID.fValue];
166         }
167 
168         ASTNode* operator->() {
169             SkASSERT(fID);
170             return &(*fNodes)[fID.fValue];
171         }
172 
173     private:
iteratorASTNode174         iterator(std::vector<ASTNode>* nodes, ID id)
175             : fNodes(nodes)
176             , fID(id) {}
177 
178         std::vector<ASTNode>* fNodes;
179 
180         ID fID;
181 
182         friend struct ASTNode;
183     };
184 
185     struct ParameterData {
ParameterDataASTNode::ParameterData186         ParameterData() {}
187 
ParameterDataASTNode::ParameterData188         ParameterData(Modifiers modifiers, StringFragment name, bool isArray)
189             : fModifiers(modifiers)
190             , fName(name)
191             , fIsArray(isArray) {}
192 
193         Modifiers fModifiers;
194         StringFragment fName;
195         bool fIsArray;
196     };
197 
198     struct VarData {
VarDataASTNode::VarData199         VarData() {}
200 
VarDataASTNode::VarData201         VarData(StringFragment name, bool isArray)
202             : fName(name)
203             , fIsArray(isArray) {}
204 
205         StringFragment fName;
206         bool fIsArray;
207     };
208 
209     struct FunctionData {
FunctionDataASTNode::FunctionData210         FunctionData() {}
211 
FunctionDataASTNode::FunctionData212         FunctionData(Modifiers modifiers, StringFragment name, size_t parameterCount)
213             : fModifiers(modifiers)
214             , fName(name)
215             , fParameterCount(parameterCount) {}
216 
217         Modifiers fModifiers;
218         StringFragment fName;
219         size_t fParameterCount;
220     };
221 
222     struct InterfaceBlockData {
InterfaceBlockDataASTNode::InterfaceBlockData223         InterfaceBlockData() {}
224 
InterfaceBlockDataASTNode::InterfaceBlockData225         InterfaceBlockData(Modifiers modifiers, StringFragment typeName, size_t declarationCount,
226                            StringFragment instanceName, bool isArray)
227             : fModifiers(modifiers)
228             , fTypeName(typeName)
229             , fDeclarationCount(declarationCount)
230             , fInstanceName(instanceName)
231             , fIsArray(isArray) {}
232 
233         Modifiers fModifiers;
234         StringFragment fTypeName;
235         size_t fDeclarationCount;
236         StringFragment fInstanceName;
237         bool fIsArray;
238     };
239 
240     struct SectionData {
SectionDataASTNode::SectionData241         SectionData() {}
242 
SectionDataASTNode::SectionData243         SectionData(StringFragment name, StringFragment argument, StringFragment text)
244             : fName(name)
245             , fArgument(argument)
246             , fText(text) {}
247 
248         StringFragment fName;
249         StringFragment fArgument;
250         StringFragment fText;
251     };
252 
253     struct NodeData {
254         // We use fBytes as a union which can hold any type of AST node, and use placement-new to
255         // copy AST objects into fBytes. Note that none of the AST objects have interesting
256         // destructors, so we do not bother doing a placement-delete on any of them in ~NodeData.
257         char fBytes[std::max({sizeof(Operator),
258                               sizeof(StringFragment),
259                               sizeof(bool),
260                               sizeof(SKSL_INT),
261                               sizeof(SKSL_FLOAT),
262                               sizeof(Modifiers),
263                               sizeof(FunctionData),
264                               sizeof(ParameterData),
265                               sizeof(VarData),
266                               sizeof(InterfaceBlockData),
267                               sizeof(SectionData)})];
268 
269         enum class Kind {
270             kOperator,
271             kStringFragment,
272             kBool,
273             kInt,
274             kFloat,
275             kModifiers,
276             kFunctionData,
277             kParameterData,
278             kVarData,
279             kInterfaceBlockData,
280             kSectionData
281         } fKind;
282 
283         NodeData() = default;
284 
NodeDataASTNode::NodeData285         NodeData(Operator op)
286             : fKind(Kind::kOperator) {
287             new (fBytes) Operator(op);
288         }
289 
NodeDataASTNode::NodeData290         NodeData(const StringFragment& data)
291             : fKind(Kind::kStringFragment) {
292             new (fBytes) StringFragment(data);
293         }
294 
NodeDataASTNode::NodeData295         NodeData(bool data)
296             : fKind(Kind::kBool) {
297             new (fBytes) bool(data);
298         }
299 
NodeDataASTNode::NodeData300         NodeData(SKSL_INT data)
301             : fKind(Kind::kInt) {
302             new (fBytes) SKSL_INT(data);
303         }
304 
NodeDataASTNode::NodeData305         NodeData(SKSL_FLOAT data)
306             : fKind(Kind::kFloat) {
307             new (fBytes) SKSL_FLOAT(data);
308         }
309 
NodeDataASTNode::NodeData310         NodeData(const Modifiers& data)
311             : fKind(Kind::kModifiers) {
312             new (fBytes) Modifiers(data);
313         }
314 
NodeDataASTNode::NodeData315         NodeData(const FunctionData& data)
316             : fKind(Kind::kFunctionData) {
317             new (fBytes) FunctionData(data);
318         }
319 
NodeDataASTNode::NodeData320         NodeData(const VarData& data)
321             : fKind(Kind::kVarData) {
322             new (fBytes) VarData(data);
323         }
324 
NodeDataASTNode::NodeData325         NodeData(const ParameterData& data)
326             : fKind(Kind::kParameterData) {
327             new (fBytes) ParameterData(data);
328         }
329 
NodeDataASTNode::NodeData330         NodeData(const InterfaceBlockData& data)
331             : fKind(Kind::kInterfaceBlockData) {
332             new (fBytes) InterfaceBlockData(data);
333         }
334 
NodeDataASTNode::NodeData335         NodeData(const SectionData& data)
336             : fKind(Kind::kSectionData) {
337             new (fBytes) SectionData(data);
338         }
339     };
340 
ASTNodeASTNode341     ASTNode()
342         : fOffset(-1)
343         , fKind(Kind::kNull) {}
344 
ASTNodeASTNode345     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind)
346         : fNodes(nodes)
347         , fOffset(offset)
348         , fKind(kind) {
349 
350         switch (kind) {
351             case Kind::kBinary:
352             case Kind::kPostfix:
353             case Kind::kPrefix:
354                 fData.fKind = NodeData::Kind::kOperator;
355                 break;
356 
357             case Kind::kBool:
358             case Kind::kIf:
359             case Kind::kSwitch:
360                 fData.fKind = NodeData::Kind::kBool;
361                 break;
362 
363             case Kind::kEnum:
364             case Kind::kEnumCase:
365             case Kind::kExtension:
366             case Kind::kField:
367             case Kind::kIdentifier:
368             case Kind::kScope:
369             case Kind::kType:
370                 fData.fKind = NodeData::Kind::kStringFragment;
371                 break;
372 
373             case Kind::kFloat:
374                 fData.fKind = NodeData::Kind::kFloat;
375                 break;
376 
377             case Kind::kFunction:
378                 fData.fKind = NodeData::Kind::kFunctionData;
379                 break;
380 
381             case Kind::kInt:
382                 fData.fKind = NodeData::Kind::kInt;
383                 break;
384 
385             case Kind::kInterfaceBlock:
386                 fData.fKind = NodeData::Kind::kInterfaceBlockData;
387                 break;
388 
389             case Kind::kModifiers:
390                 fData.fKind = NodeData::Kind::kModifiers;
391                 break;
392 
393             case Kind::kParameter:
394                 fData.fKind = NodeData::Kind::kParameterData;
395                 break;
396 
397             case Kind::kVarDeclaration:
398                 fData.fKind = NodeData::Kind::kVarData;
399                 break;
400 
401             default:
402                 break;
403         }
404     }
405 
ASTNodeASTNode406     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, Operator op)
407         : fNodes(nodes)
408         , fData(op)
409         , fOffset(offset)
410         , fKind(kind) {}
411 
ASTNodeASTNode412     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, StringFragment s)
413         : fNodes(nodes)
414         , fData(s)
415         , fOffset(offset)
416         , fKind(kind) {}
417 
ASTNodeASTNode418     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, const char* s)
419         : fNodes(nodes)
420         , fData(StringFragment(s))
421         , fOffset(offset)
422         , fKind(kind) {}
423 
ASTNodeASTNode424     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, bool b)
425         : fNodes(nodes)
426         , fData(b)
427         , fOffset(offset)
428         , fKind(kind) {}
429 
ASTNodeASTNode430     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, SKSL_INT i)
431         : fNodes(nodes)
432         , fData(i)
433         , fOffset(offset)
434         , fKind(kind) {}
435 
ASTNodeASTNode436     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, SKSL_FLOAT f)
437         : fNodes(nodes)
438         , fData(f)
439         , fOffset(offset)
440         , fKind(kind) {}
441 
ASTNodeASTNode442     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, Modifiers m)
443         : fNodes(nodes)
444         , fData(m)
445         , fOffset(offset)
446         , fKind(kind) {}
447 
ASTNodeASTNode448     ASTNode(std::vector<ASTNode>* nodes, int offset, Kind kind, SectionData s)
449         : fNodes(nodes)
450         , fData(s)
451         , fOffset(offset)
452         , fKind(kind) {}
453 
454     operator bool() const {
455         return fKind != Kind::kNull;
456     }
457 
getOperatorASTNode458     Operator getOperator() const {
459         SkASSERT(fData.fKind == NodeData::Kind::kOperator);
460         return *reinterpret_cast<const Operator*>(fData.fBytes);
461     }
462 
getBoolASTNode463     bool getBool() const {
464         SkASSERT(fData.fKind == NodeData::Kind::kBool);
465         return *reinterpret_cast<const bool*>(fData.fBytes);
466     }
467 
getIntASTNode468     SKSL_INT getInt() const {
469         SkASSERT(fData.fKind == NodeData::Kind::kInt);
470         return *reinterpret_cast<const SKSL_INT*>(fData.fBytes);
471     }
472 
getFloatASTNode473     SKSL_FLOAT getFloat() const {
474         SkASSERT(fData.fKind == NodeData::Kind::kFloat);
475         return *reinterpret_cast<const SKSL_FLOAT*>(fData.fBytes);
476     }
477 
getStringASTNode478     const StringFragment& getString() const {
479         SkASSERT(fData.fKind == NodeData::Kind::kStringFragment);
480         return *reinterpret_cast<const StringFragment*>(fData.fBytes);
481     }
482 
getModifiersASTNode483     const Modifiers& getModifiers() const {
484         SkASSERT(fData.fKind == NodeData::Kind::kModifiers);
485         return *reinterpret_cast<const Modifiers*>(fData.fBytes);
486     }
487 
setModifiersASTNode488     void setModifiers(const Modifiers& m) {
489         new (fData.fBytes) Modifiers(m);
490     }
491 
getParameterDataASTNode492     const ParameterData& getParameterData() const {
493         SkASSERT(fData.fKind == NodeData::Kind::kParameterData);
494         return *reinterpret_cast<const ParameterData*>(fData.fBytes);
495     }
496 
setParameterDataASTNode497     void setParameterData(const ASTNode::ParameterData& pd) {
498         new (fData.fBytes) ParameterData(pd);
499     }
500 
getVarDataASTNode501     const VarData& getVarData() const {
502         SkASSERT(fData.fKind == NodeData::Kind::kVarData);
503         return *reinterpret_cast<const VarData*>(fData.fBytes);
504     }
505 
setVarDataASTNode506     void setVarData(const ASTNode::VarData& vd) {
507         new (fData.fBytes) VarData(vd);
508     }
509 
getFunctionDataASTNode510     const FunctionData& getFunctionData() const {
511         SkASSERT(fData.fKind == NodeData::Kind::kFunctionData);
512         return *reinterpret_cast<const FunctionData*>(fData.fBytes);
513     }
514 
setFunctionDataASTNode515     void setFunctionData(const ASTNode::FunctionData& fd) {
516         new (fData.fBytes) FunctionData(fd);
517     }
518 
getInterfaceBlockDataASTNode519     const InterfaceBlockData& getInterfaceBlockData() const {
520         SkASSERT(fData.fKind == NodeData::Kind::kInterfaceBlockData);
521         return *reinterpret_cast<const InterfaceBlockData*>(fData.fBytes);
522     }
523 
setInterfaceBlockDataASTNode524     void setInterfaceBlockData(const ASTNode::InterfaceBlockData& id) {
525         new (fData.fBytes) InterfaceBlockData(id);
526     }
527 
getSectionDataASTNode528     const SectionData& getSectionData() const {
529         SkASSERT(fData.fKind == NodeData::Kind::kSectionData);
530         return *reinterpret_cast<const SectionData*>(fData.fBytes);
531     }
532 
addChildASTNode533     void addChild(ID id) {
534         SkASSERT(!(*fNodes)[id.fValue].fNext);
535         if (fLastChild) {
536             SkASSERT(!(*fNodes)[fLastChild.fValue].fNext);
537             (*fNodes)[fLastChild.fValue].fNext = id;
538         } else {
539             fFirstChild = id;
540         }
541         fLastChild = id;
542         SkASSERT(!(*fNodes)[fLastChild.fValue].fNext);
543     }
544 
beginASTNode545     iterator begin() const {
546         return iterator(fNodes, fFirstChild);
547     }
548 
endASTNode549     iterator end() const {
550         return iterator(fNodes, ID(-1));
551     }
552 
553 #ifdef SK_DEBUG
554     String description() const;
555 #endif
556 
557     std::vector<ASTNode>* fNodes;
558 
559     NodeData fData;
560 
561     int fOffset;
562 
563     Kind fKind;
564 
565     ID fFirstChild;
566 
567     ID fLastChild;
568 
569     ID fNext;
570 };
571 
572 }  // namespace SkSL
573 
574 #endif
575