1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <functional>
8 #include <unordered_map>
9 #include <unordered_set>
10 #include <vector>
11 
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
14 #include "compiler/translator/TranslatorMetalDirect/ToposortStructs.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 
18 using namespace sh;
19 
20 ////////////////////////////////////////////////////////////////////////////////
21 
22 namespace
23 {
24 
// Outgoing edges of a node: the set of nodes it points to.
template <typename Node>
using Edges = std::unordered_set<Node>;

// Adjacency-set graph: every key is a node, mapped to its outgoing edges.
template <typename Node>
using Graph = std::unordered_map<Node, Edges<Node>>;
30 
BuildGraphImpl(SymbolEnv & symbolEnv,Graph<const TStructure * > & g,const TStructure * s)31 void BuildGraphImpl(SymbolEnv &symbolEnv, Graph<const TStructure *> &g, const TStructure *s)
32 {
33     if (g.find(s) != g.end())
34     {
35         return;
36     }
37 
38     Edges<const TStructure *> &es = g[s];
39 
40     const TFieldList &fs = s->fields();
41     for (const TField *f : fs)
42     {
43         if (const TStructure *z = symbolEnv.remap(f->type()->getStruct()))
44         {
45             es.insert(z);
46             BuildGraphImpl(symbolEnv, g, z);
47             Edges<const TStructure *> &ez = g[z];
48             es.insert(ez.begin(), ez.end());
49         }
50     }
51 }
52 
BuildGraph(SymbolEnv & symbolEnv,const std::vector<const TStructure * > & structs)53 Graph<const TStructure *> BuildGraph(SymbolEnv &symbolEnv,
54                                      const std::vector<const TStructure *> &structs)
55 {
56     Graph<const TStructure *> g;
57     for (const TStructure *s : structs)
58     {
59         BuildGraphImpl(symbolEnv, g, s);
60     }
61     return g;
62 }
63 
64 // Algorthm: https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
65 template <typename T>
Toposort(const Graph<T> & g)66 std::vector<T> Toposort(const Graph<T> &g)
67 {
68     // nodes with temporary mark
69     std::unordered_set<T> temps;
70 
71     // nodes without permanent mark
72     std::unordered_set<T> invPerms;
73     for (const auto &entry : g)
74     {
75         invPerms.insert(entry.first);
76     }
77 
78     // L <- Empty list that will contain the sorted elements
79     std::vector<T> L;
80 
81     // function visit(node n)
82     std::function<void(T)> visit = [&](T n) -> void {
83         // if n has a permanent mark then
84         if (invPerms.find(n) == invPerms.end())
85         {
86             // return
87             return;
88         }
89         // if n has a temporary mark then
90         if (temps.find(n) != temps.end())
91         {
92             // stop   (not a DAG)
93             UNREACHABLE();
94         }
95 
96         // mark n with a temporary mark
97         temps.insert(n);
98 
99         // for each node m with an edge from n to m do
100         auto enIter = g.find(n);
101         ASSERT(enIter != g.end());
102         const Edges<T> &en = enIter->second;
103         for (T m : en)
104         {
105             // visit(m)
106             visit(m);
107         }
108 
109         // remove temporary mark from n
110         temps.erase(n);
111         // mark n with a permanent mark
112         invPerms.erase(n);
113         // add n to head of L
114         L.push_back(n);
115     };
116 
117     // while exists nodes without a permanent mark do
118     while (!invPerms.empty())
119     {
120         // select an unmarked node n
121         T n = *invPerms.begin();
122         // visit(n)
123         visit(n);
124     }
125 
126     return L;
127 }
128 
CreateStructEqualityFunction(TSymbolTable & symbolTable,const TStructure & aStructType)129 TIntermFunctionDefinition *CreateStructEqualityFunction(TSymbolTable &symbolTable,
130                                                         const TStructure &aStructType)
131 {
132     ////////////////////
133 
134     auto &funcEquality =
135         *new TFunction(&symbolTable, ImmutableString("equal"), SymbolType::AngleInternal,
136                        new TType(TBasicType::EbtBool), true);
137     auto &aStruct = CreateInstanceVariable(symbolTable, aStructType, Name("a"));
138     auto &bStruct = CreateInstanceVariable(symbolTable, aStructType, Name("b"));
139     funcEquality.addParameter(&aStruct);
140     funcEquality.addParameter(&bStruct);
141 
142     auto &bodyEquality = *new TIntermBlock();
143     std::vector<TIntermTyped *> andNodes;
144     ////////////////////
145 
146     const TFieldList &aFields = aStructType.fields();
147     const size_t size         = aFields.size();
148 
149     auto testEquality = [&](TIntermTyped &a, TIntermTyped &b) -> TIntermTyped * {
150         ASSERT(a.getType() == b.getType());
151         const TType &type = a.getType();
152         if (type.isVector() || type.isMatrix() || type.getStruct())
153         {
154             auto *func =
155                 new TFunction(&symbolTable, ImmutableString("equal"), SymbolType::AngleInternal,
156                               new TType(TBasicType::EbtBool), true);
157             return TIntermAggregate::CreateFunctionCall(*func, new TIntermSequence{&a, &b});
158         }
159         else
160         {
161             return new TIntermBinary(TOperator::EOpEqual, &a, &b);
162         }
163     };
164 
165     for (size_t idx = 0; idx < size; ++idx)
166     {
167         const TField &aField    = *aFields[idx];
168         const TType &aFieldType = *aField.type();
169         auto &aFieldName        = aField.name();
170 
171         if (aFieldType.isArray())
172         {
173             ASSERT(!aFieldType.isArrayOfArrays());  // TODO
174             int dim = aFieldType.getOutermostArraySize();
175             for (int d = 0; d < dim; ++d)
176             {
177                 auto &aAccess = AccessIndex(AccessField(aStruct, aFieldName), d);
178                 auto &bAccess = AccessIndex(AccessField(bStruct, aFieldName), d);
179                 auto *eqNode  = testEquality(bAccess, aAccess);
180                 andNodes.push_back(eqNode);
181             }
182         }
183         else
184         {
185             auto &aAccess = AccessField(aStruct, aFieldName);
186             auto &bAccess = AccessField(bStruct, aFieldName);
187             auto *eqNode  = testEquality(bAccess, aAccess);
188             andNodes.push_back(eqNode);
189         }
190     }
191 
192     ASSERT(andNodes.size() > 0);  // Empty structs are not allowed in GLSL
193     TIntermTyped *outNode = andNodes.back();
194     andNodes.pop_back();
195     for (TIntermTyped *andNode : andNodes)
196     {
197         outNode = new TIntermBinary(TOperator::EOpLogicalAnd, andNode, outNode);
198     }
199     bodyEquality.appendStatement(new TIntermBranch(TOperator::EOpReturn, outNode));
200     auto *funcProtoEquality = new TIntermFunctionPrototype(&funcEquality);
201     return new TIntermFunctionDefinition(funcProtoEquality, &bodyEquality);
202 }
203 
204 struct DeclaredStructure
205 {
206     TIntermDeclaration *declNode;
207     TIntermFunctionDefinition *equalityFunctionDefinition;
208     const TStructure *structure;
209 };
210 
GetAsDeclaredStructure(SymbolEnv & symbolEnv,TIntermNode & node,DeclaredStructure & out,TSymbolTable & symbolTable,const std::unordered_set<const TStructure * > & usedStructs)211 bool GetAsDeclaredStructure(SymbolEnv &symbolEnv,
212                             TIntermNode &node,
213                             DeclaredStructure &out,
214                             TSymbolTable &symbolTable,
215                             const std::unordered_set<const TStructure *> &usedStructs)
216 {
217     if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
218     {
219         ASSERT(declNode->getChildCount() == 1);
220         TIntermNode &childNode = *declNode->getChildNode(0);
221 
222         if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
223         {
224             const TVariable &var = symbolNode->variable();
225             const TType &type    = var.getType();
226             if (const TStructure *structure = symbolEnv.remap(type.getStruct()))
227             {
228                 if (type.isStructSpecifier())
229                 {
230                     out.declNode  = declNode;
231                     out.structure = structure;
232                     out.equalityFunctionDefinition =
233                         usedStructs.find(structure) == usedStructs.end()
234                             ? nullptr
235                             : CreateStructEqualityFunction(symbolTable, *structure);
236                     return true;
237                 }
238             }
239         }
240     }
241     return false;
242 }
243 
244 class FindStructEqualityUse : public TIntermTraverser
245 {
246   public:
247     SymbolEnv &mSymbolEnv;
248     std::unordered_set<const TStructure *> mUsedStructs;
249 
FindStructEqualityUse(SymbolEnv & symbolEnv)250     FindStructEqualityUse(SymbolEnv &symbolEnv)
251         : TIntermTraverser(false, false, true), mSymbolEnv(symbolEnv)
252     {}
253 
visitBinary(Visit,TIntermBinary * binary)254     bool visitBinary(Visit, TIntermBinary *binary) override
255     {
256         const TOperator op = binary->getOp();
257 
258         switch (op)
259         {
260             case TOperator::EOpEqual:
261             case TOperator::EOpNotEqual:
262             {
263                 const TType &leftType  = binary->getLeft()->getType();
264                 const TType &rightType = binary->getRight()->getType();
265                 ASSERT(leftType.getStruct() == rightType.getStruct());
266                 if (const TStructure *structure = mSymbolEnv.remap(leftType.getStruct()))
267                 {
268                     useStruct(*structure);
269                 }
270             }
271             break;
272 
273             default:
274                 break;
275         }
276 
277         return true;
278     }
279 
280   private:
useStruct(const TStructure & structure)281     void useStruct(const TStructure &structure)
282     {
283         if (mUsedStructs.insert(&structure).second)
284         {
285             for (const TField *field : structure.fields())
286             {
287                 if (const TStructure *subStruct = mSymbolEnv.remap(field->type()->getStruct()))
288                 {
289                     useStruct(*subStruct);
290                 }
291             }
292         }
293     }
294 };
295 
296 }  // anonymous namespace
297 
298 ////////////////////////////////////////////////////////////////////////////////
299 
ToposortStructs(TCompiler & compiler,SymbolEnv & symbolEnv,TIntermBlock & root,ProgramPreludeConfig & ppc)300 bool sh::ToposortStructs(TCompiler &compiler,
301                          SymbolEnv &symbolEnv,
302                          TIntermBlock &root,
303                          ProgramPreludeConfig &ppc)
304 {
305     FindStructEqualityUse finder(symbolEnv);
306     root.traverse(&finder);
307     ppc.hasStructEq = !finder.mUsedStructs.empty();
308 
309     std::vector<DeclaredStructure> declaredStructs;
310     std::vector<TIntermNode *> nonStructStmtNodes;
311 
312     {
313         DeclaredStructure declaredStruct;
314         const size_t stmtCount = root.getChildCount();
315         for (size_t i = 0; i < stmtCount; ++i)
316         {
317             TIntermNode &stmtNode = *root.getChildNode(i);
318             if (GetAsDeclaredStructure(symbolEnv, stmtNode, declaredStruct,
319                                        compiler.getSymbolTable(), finder.mUsedStructs))
320             {
321                 declaredStructs.push_back(declaredStruct);
322             }
323             else
324             {
325                 nonStructStmtNodes.push_back(&stmtNode);
326             }
327         }
328     }
329 
330     {
331         std::vector<const TStructure *> structs;
332         std::unordered_map<const TStructure *, DeclaredStructure> rawToDeclared;
333 
334         for (const DeclaredStructure &d : declaredStructs)
335         {
336             structs.push_back(d.structure);
337             ASSERT(rawToDeclared.find(d.structure) == rawToDeclared.end());
338             rawToDeclared[d.structure] = d;
339         }
340 
341         // Note: Graph may contain more than only explicitly declared structures.
342         Graph<const TStructure *> g                   = BuildGraph(symbolEnv, structs);
343         std::vector<const TStructure *> sortedStructs = Toposort(g);
344         ASSERT(declaredStructs.size() <= sortedStructs.size());
345 
346         declaredStructs.clear();
347         for (const TStructure *s : sortedStructs)
348         {
349             auto it = rawToDeclared.find(s);
350             if (it != rawToDeclared.end())
351             {
352                 auto &d = it->second;
353                 ASSERT(d.declNode);
354                 declaredStructs.push_back(d);
355             }
356         }
357     }
358 
359     {
360         TIntermSequence newStmtNodes;
361 
362         for (DeclaredStructure &declaredStruct : declaredStructs)
363         {
364             ASSERT(declaredStruct.declNode);
365             newStmtNodes.push_back(declaredStruct.declNode);
366             if (declaredStruct.equalityFunctionDefinition)
367             {
368                 newStmtNodes.push_back(declaredStruct.equalityFunctionDefinition);
369             }
370         }
371 
372         for (TIntermNode *stmtNode : nonStructStmtNodes)
373         {
374             ASSERT(stmtNode);
375             newStmtNodes.push_back(stmtNode);
376         }
377 
378         *root.getSequence() = newStmtNodes;
379     }
380 
381     return compiler.validateAST(&root);
382 }
383