1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <functional>
8 #include <unordered_map>
9 #include <unordered_set>
10 #include <vector>
11
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
14 #include "compiler/translator/TranslatorMetalDirect/ToposortStructs.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17
18 using namespace sh;
19
20 ////////////////////////////////////////////////////////////////////////////////
21
22 namespace
23 {
24
// The set of nodes that a single graph node has outgoing edges to.
template <class Node>
using Edges = std::unordered_set<Node>;

// Directed graph in adjacency-set form: every node maps to the set of nodes it points at.
template <class Node>
using Graph = std::unordered_map<Node, Edges<Node>>;
30
BuildGraphImpl(SymbolEnv & symbolEnv,Graph<const TStructure * > & g,const TStructure * s)31 void BuildGraphImpl(SymbolEnv &symbolEnv, Graph<const TStructure *> &g, const TStructure *s)
32 {
33 if (g.find(s) != g.end())
34 {
35 return;
36 }
37
38 Edges<const TStructure *> &es = g[s];
39
40 const TFieldList &fs = s->fields();
41 for (const TField *f : fs)
42 {
43 if (const TStructure *z = symbolEnv.remap(f->type()->getStruct()))
44 {
45 es.insert(z);
46 BuildGraphImpl(symbolEnv, g, z);
47 Edges<const TStructure *> &ez = g[z];
48 es.insert(ez.begin(), ez.end());
49 }
50 }
51 }
52
BuildGraph(SymbolEnv & symbolEnv,const std::vector<const TStructure * > & structs)53 Graph<const TStructure *> BuildGraph(SymbolEnv &symbolEnv,
54 const std::vector<const TStructure *> &structs)
55 {
56 Graph<const TStructure *> g;
57 for (const TStructure *s : structs)
58 {
59 BuildGraphImpl(symbolEnv, g, s);
60 }
61 return g;
62 }
63
64 // Algorthm: https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
65 template <typename T>
Toposort(const Graph<T> & g)66 std::vector<T> Toposort(const Graph<T> &g)
67 {
68 // nodes with temporary mark
69 std::unordered_set<T> temps;
70
71 // nodes without permanent mark
72 std::unordered_set<T> invPerms;
73 for (const auto &entry : g)
74 {
75 invPerms.insert(entry.first);
76 }
77
78 // L <- Empty list that will contain the sorted elements
79 std::vector<T> L;
80
81 // function visit(node n)
82 std::function<void(T)> visit = [&](T n) -> void {
83 // if n has a permanent mark then
84 if (invPerms.find(n) == invPerms.end())
85 {
86 // return
87 return;
88 }
89 // if n has a temporary mark then
90 if (temps.find(n) != temps.end())
91 {
92 // stop (not a DAG)
93 UNREACHABLE();
94 }
95
96 // mark n with a temporary mark
97 temps.insert(n);
98
99 // for each node m with an edge from n to m do
100 auto enIter = g.find(n);
101 ASSERT(enIter != g.end());
102 const Edges<T> &en = enIter->second;
103 for (T m : en)
104 {
105 // visit(m)
106 visit(m);
107 }
108
109 // remove temporary mark from n
110 temps.erase(n);
111 // mark n with a permanent mark
112 invPerms.erase(n);
113 // add n to head of L
114 L.push_back(n);
115 };
116
117 // while exists nodes without a permanent mark do
118 while (!invPerms.empty())
119 {
120 // select an unmarked node n
121 T n = *invPerms.begin();
122 // visit(n)
123 visit(n);
124 }
125
126 return L;
127 }
128
CreateStructEqualityFunction(TSymbolTable & symbolTable,const TStructure & aStructType)129 TIntermFunctionDefinition *CreateStructEqualityFunction(TSymbolTable &symbolTable,
130 const TStructure &aStructType)
131 {
132 ////////////////////
133
134 auto &funcEquality =
135 *new TFunction(&symbolTable, ImmutableString("equal"), SymbolType::AngleInternal,
136 new TType(TBasicType::EbtBool), true);
137 auto &aStruct = CreateInstanceVariable(symbolTable, aStructType, Name("a"));
138 auto &bStruct = CreateInstanceVariable(symbolTable, aStructType, Name("b"));
139 funcEquality.addParameter(&aStruct);
140 funcEquality.addParameter(&bStruct);
141
142 auto &bodyEquality = *new TIntermBlock();
143 std::vector<TIntermTyped *> andNodes;
144 ////////////////////
145
146 const TFieldList &aFields = aStructType.fields();
147 const size_t size = aFields.size();
148
149 auto testEquality = [&](TIntermTyped &a, TIntermTyped &b) -> TIntermTyped * {
150 ASSERT(a.getType() == b.getType());
151 const TType &type = a.getType();
152 if (type.isVector() || type.isMatrix() || type.getStruct())
153 {
154 auto *func =
155 new TFunction(&symbolTable, ImmutableString("equal"), SymbolType::AngleInternal,
156 new TType(TBasicType::EbtBool), true);
157 return TIntermAggregate::CreateFunctionCall(*func, new TIntermSequence{&a, &b});
158 }
159 else
160 {
161 return new TIntermBinary(TOperator::EOpEqual, &a, &b);
162 }
163 };
164
165 for (size_t idx = 0; idx < size; ++idx)
166 {
167 const TField &aField = *aFields[idx];
168 const TType &aFieldType = *aField.type();
169 auto &aFieldName = aField.name();
170
171 if (aFieldType.isArray())
172 {
173 ASSERT(!aFieldType.isArrayOfArrays()); // TODO
174 int dim = aFieldType.getOutermostArraySize();
175 for (int d = 0; d < dim; ++d)
176 {
177 auto &aAccess = AccessIndex(AccessField(aStruct, aFieldName), d);
178 auto &bAccess = AccessIndex(AccessField(bStruct, aFieldName), d);
179 auto *eqNode = testEquality(bAccess, aAccess);
180 andNodes.push_back(eqNode);
181 }
182 }
183 else
184 {
185 auto &aAccess = AccessField(aStruct, aFieldName);
186 auto &bAccess = AccessField(bStruct, aFieldName);
187 auto *eqNode = testEquality(bAccess, aAccess);
188 andNodes.push_back(eqNode);
189 }
190 }
191
192 ASSERT(andNodes.size() > 0); // Empty structs are not allowed in GLSL
193 TIntermTyped *outNode = andNodes.back();
194 andNodes.pop_back();
195 for (TIntermTyped *andNode : andNodes)
196 {
197 outNode = new TIntermBinary(TOperator::EOpLogicalAnd, andNode, outNode);
198 }
199 bodyEquality.appendStatement(new TIntermBranch(TOperator::EOpReturn, outNode));
200 auto *funcProtoEquality = new TIntermFunctionPrototype(&funcEquality);
201 return new TIntermFunctionDefinition(funcProtoEquality, &bodyEquality);
202 }
203
204 struct DeclaredStructure
205 {
206 TIntermDeclaration *declNode;
207 TIntermFunctionDefinition *equalityFunctionDefinition;
208 const TStructure *structure;
209 };
210
GetAsDeclaredStructure(SymbolEnv & symbolEnv,TIntermNode & node,DeclaredStructure & out,TSymbolTable & symbolTable,const std::unordered_set<const TStructure * > & usedStructs)211 bool GetAsDeclaredStructure(SymbolEnv &symbolEnv,
212 TIntermNode &node,
213 DeclaredStructure &out,
214 TSymbolTable &symbolTable,
215 const std::unordered_set<const TStructure *> &usedStructs)
216 {
217 if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
218 {
219 ASSERT(declNode->getChildCount() == 1);
220 TIntermNode &childNode = *declNode->getChildNode(0);
221
222 if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
223 {
224 const TVariable &var = symbolNode->variable();
225 const TType &type = var.getType();
226 if (const TStructure *structure = symbolEnv.remap(type.getStruct()))
227 {
228 if (type.isStructSpecifier())
229 {
230 out.declNode = declNode;
231 out.structure = structure;
232 out.equalityFunctionDefinition =
233 usedStructs.find(structure) == usedStructs.end()
234 ? nullptr
235 : CreateStructEqualityFunction(symbolTable, *structure);
236 return true;
237 }
238 }
239 }
240 }
241 return false;
242 }
243
244 class FindStructEqualityUse : public TIntermTraverser
245 {
246 public:
247 SymbolEnv &mSymbolEnv;
248 std::unordered_set<const TStructure *> mUsedStructs;
249
FindStructEqualityUse(SymbolEnv & symbolEnv)250 FindStructEqualityUse(SymbolEnv &symbolEnv)
251 : TIntermTraverser(false, false, true), mSymbolEnv(symbolEnv)
252 {}
253
visitBinary(Visit,TIntermBinary * binary)254 bool visitBinary(Visit, TIntermBinary *binary) override
255 {
256 const TOperator op = binary->getOp();
257
258 switch (op)
259 {
260 case TOperator::EOpEqual:
261 case TOperator::EOpNotEqual:
262 {
263 const TType &leftType = binary->getLeft()->getType();
264 const TType &rightType = binary->getRight()->getType();
265 ASSERT(leftType.getStruct() == rightType.getStruct());
266 if (const TStructure *structure = mSymbolEnv.remap(leftType.getStruct()))
267 {
268 useStruct(*structure);
269 }
270 }
271 break;
272
273 default:
274 break;
275 }
276
277 return true;
278 }
279
280 private:
useStruct(const TStructure & structure)281 void useStruct(const TStructure &structure)
282 {
283 if (mUsedStructs.insert(&structure).second)
284 {
285 for (const TField *field : structure.fields())
286 {
287 if (const TStructure *subStruct = mSymbolEnv.remap(field->type()->getStruct()))
288 {
289 useStruct(*subStruct);
290 }
291 }
292 }
293 }
294 };
295
296 } // anonymous namespace
297
298 ////////////////////////////////////////////////////////////////////////////////
299
ToposortStructs(TCompiler & compiler,SymbolEnv & symbolEnv,TIntermBlock & root,ProgramPreludeConfig & ppc)300 bool sh::ToposortStructs(TCompiler &compiler,
301 SymbolEnv &symbolEnv,
302 TIntermBlock &root,
303 ProgramPreludeConfig &ppc)
304 {
305 FindStructEqualityUse finder(symbolEnv);
306 root.traverse(&finder);
307 ppc.hasStructEq = !finder.mUsedStructs.empty();
308
309 std::vector<DeclaredStructure> declaredStructs;
310 std::vector<TIntermNode *> nonStructStmtNodes;
311
312 {
313 DeclaredStructure declaredStruct;
314 const size_t stmtCount = root.getChildCount();
315 for (size_t i = 0; i < stmtCount; ++i)
316 {
317 TIntermNode &stmtNode = *root.getChildNode(i);
318 if (GetAsDeclaredStructure(symbolEnv, stmtNode, declaredStruct,
319 compiler.getSymbolTable(), finder.mUsedStructs))
320 {
321 declaredStructs.push_back(declaredStruct);
322 }
323 else
324 {
325 nonStructStmtNodes.push_back(&stmtNode);
326 }
327 }
328 }
329
330 {
331 std::vector<const TStructure *> structs;
332 std::unordered_map<const TStructure *, DeclaredStructure> rawToDeclared;
333
334 for (const DeclaredStructure &d : declaredStructs)
335 {
336 structs.push_back(d.structure);
337 ASSERT(rawToDeclared.find(d.structure) == rawToDeclared.end());
338 rawToDeclared[d.structure] = d;
339 }
340
341 // Note: Graph may contain more than only explicitly declared structures.
342 Graph<const TStructure *> g = BuildGraph(symbolEnv, structs);
343 std::vector<const TStructure *> sortedStructs = Toposort(g);
344 ASSERT(declaredStructs.size() <= sortedStructs.size());
345
346 declaredStructs.clear();
347 for (const TStructure *s : sortedStructs)
348 {
349 auto it = rawToDeclared.find(s);
350 if (it != rawToDeclared.end())
351 {
352 auto &d = it->second;
353 ASSERT(d.declNode);
354 declaredStructs.push_back(d);
355 }
356 }
357 }
358
359 {
360 TIntermSequence newStmtNodes;
361
362 for (DeclaredStructure &declaredStruct : declaredStructs)
363 {
364 ASSERT(declaredStruct.declNode);
365 newStmtNodes.push_back(declaredStruct.declNode);
366 if (declaredStruct.equalityFunctionDefinition)
367 {
368 newStmtNodes.push_back(declaredStruct.equalityFunctionDefinition);
369 }
370 }
371
372 for (TIntermNode *stmtNode : nonStructStmtNodes)
373 {
374 ASSERT(stmtNode);
375 newStmtNodes.push_back(stmtNode);
376 }
377
378 *root.getSequence() = newStmtNodes;
379 }
380
381 return compiler.validateAST(&root);
382 }
383