1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
#include <cctype>
#include <cstring>
#include <limits>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
13
14 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
15 #include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
16 #include "compiler/translator/TranslatorMetalDirect/RewriteKeywords.h"
17
18 using namespace sh;
19
20 ////////////////////////////////////////////////////////////////////////////////
21
22 namespace
23 {
24
25 template <typename T>
26 using Remapping = std::unordered_map<const T *, const T *>;
27
28 class Rewriter : public TIntermRebuild
29 {
30 private:
31 const std::set<ImmutableString> &mKeywords;
32 IdGen &mIdGen;
33 Remapping<TField> modifiedFields;
34 Remapping<TFieldList> mFieldLists;
35 Remapping<TFunction> mFunctions;
36 Remapping<TInterfaceBlock> mInterfaceBlocks;
37 Remapping<TStructure> mStructures;
38 Remapping<TVariable> mVariables;
39 std::map<ImmutableString, std::string> mPredefinedNames;
40 std::string mNewNameBuffer;
41
42 private:
43 template <typename T>
maybeCreateNewName(T const & object)44 ImmutableString maybeCreateNewName(T const &object)
45 {
46 if (needsRenaming(object, false))
47 {
48 auto it = mPredefinedNames.find(Name(object).rawName());
49 if (it != mPredefinedNames.end())
50 {
51 return ImmutableString(it->second);
52 }
53 return mIdGen.createNewName(Name(object)).rawName();
54 }
55 return Name(object).rawName();
56 }
57
createRenamed(const TField & field)58 const TField *createRenamed(const TField &field)
59 {
60 auto *renamed =
61 new TField(const_cast<TType *>(&getRenamedOrOriginal(*field.type())),
62 maybeCreateNewName(field), field.line(), SymbolType::AngleInternal);
63
64 return renamed;
65 }
66
createRenamed(const TFieldList & fieldList)67 const TFieldList *createRenamed(const TFieldList &fieldList)
68 {
69 auto *renamed = new TFieldList();
70 for (const TField *field : fieldList)
71 {
72 renamed->push_back(const_cast<TField *>(&getRenamedOrOriginal(*field)));
73 }
74 return renamed;
75 }
76
createRenamed(const TFunction & function)77 const TFunction *createRenamed(const TFunction &function)
78 {
79 auto *renamed =
80 new TFunction(&mSymbolTable, maybeCreateNewName(function), SymbolType::AngleInternal,
81 &getRenamedOrOriginal(function.getReturnType()),
82 function.isKnownToNotHaveSideEffects());
83
84 const size_t paramCount = function.getParamCount();
85 for (size_t i = 0; i < paramCount; ++i)
86 {
87 const TVariable ¶m = *function.getParam(i);
88 renamed->addParameter(&getRenamedOrOriginal(param));
89 }
90
91 if (function.isDefined())
92 {
93 renamed->setDefined();
94 }
95
96 if (function.hasPrototypeDeclaration())
97 {
98 renamed->setHasPrototypeDeclaration();
99 }
100
101 return renamed;
102 }
103
createRenamed(const TInterfaceBlock & interfaceBlock)104 const TInterfaceBlock *createRenamed(const TInterfaceBlock &interfaceBlock)
105 {
106 TLayoutQualifier layoutQualifier = TLayoutQualifier::Create();
107 layoutQualifier.blockStorage = interfaceBlock.blockStorage();
108 layoutQualifier.binding = interfaceBlock.blockBinding();
109
110 auto *renamed =
111 new TInterfaceBlock(&mSymbolTable, maybeCreateNewName(interfaceBlock),
112 &getRenamedOrOriginal(interfaceBlock.fields()), layoutQualifier,
113 SymbolType::AngleInternal, interfaceBlock.extensions());
114
115 return renamed;
116 }
117
createRenamed(const TStructure & structure)118 const TStructure *createRenamed(const TStructure &structure)
119 {
120 auto *renamed =
121 new TStructure(&mSymbolTable, maybeCreateNewName(structure),
122 &getRenamedOrOriginal(structure.fields()), SymbolType::AngleInternal);
123
124 renamed->setAtGlobalScope(structure.atGlobalScope());
125
126 return renamed;
127 }
128
createRenamed(const TType & type)129 const TType *createRenamed(const TType &type)
130 {
131 TType *renamed;
132
133 if (const TStructure *structure = type.getStruct())
134 {
135 renamed = new TType(&getRenamedOrOriginal(*structure), type.isStructSpecifier());
136 }
137 else if (const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock())
138 {
139 renamed = new TType(&getRenamedOrOriginal(*interfaceBlock), type.getQualifier(),
140 type.getLayoutQualifier());
141 }
142 else
143 {
144 UNREACHABLE(); // Can't rename built-in types.
145 renamed = nullptr;
146 }
147
148 if (type.isArray())
149 {
150 renamed->makeArrays(type.getArraySizes());
151 }
152 renamed->setPrecise(type.isPrecise());
153 renamed->setInvariant(type.isInvariant());
154 renamed->setMemoryQualifier(type.getMemoryQualifier());
155 renamed->setLayoutQualifier(type.getLayoutQualifier());
156
157 return renamed;
158 }
159
createRenamed(const TVariable & variable)160 const TVariable *createRenamed(const TVariable &variable)
161 {
162 auto *renamed = new TVariable(&mSymbolTable, maybeCreateNewName(variable),
163 &getRenamedOrOriginal(variable.getType()),
164 SymbolType::AngleInternal, variable.extensions());
165
166 return renamed;
167 }
168
169 template <typename T>
tryGetRenamedImpl(const T & object,Remapping<T> * remapping)170 const T *tryGetRenamedImpl(const T &object, Remapping<T> *remapping)
171 {
172 if (!needsRenaming(object, true))
173 {
174 return nullptr;
175 }
176
177 if (remapping)
178 {
179 auto it = remapping->find(&object);
180 if (it != remapping->end())
181 {
182 return it->second;
183 }
184 }
185
186 const T *renamedObject = createRenamed(object);
187
188 if (remapping)
189 {
190 (*remapping)[&object] = renamedObject;
191 }
192
193 return renamedObject;
194 }
195
tryGetRenamed(const TField & field)196 const TField *tryGetRenamed(const TField &field)
197 {
198 return tryGetRenamedImpl(field, &modifiedFields);
199 }
200
tryGetRenamed(const TFieldList & fieldList)201 const TFieldList *tryGetRenamed(const TFieldList &fieldList)
202 {
203 return tryGetRenamedImpl(fieldList, &mFieldLists);
204 }
205
tryGetRenamed(const TFunction & func)206 const TFunction *tryGetRenamed(const TFunction &func)
207 {
208 return tryGetRenamedImpl(func, &mFunctions);
209 }
210
tryGetRenamed(const TInterfaceBlock & interfaceBlock)211 const TInterfaceBlock *tryGetRenamed(const TInterfaceBlock &interfaceBlock)
212 {
213 return tryGetRenamedImpl(interfaceBlock, &mInterfaceBlocks);
214 }
215
tryGetRenamed(const TStructure & structure)216 const TStructure *tryGetRenamed(const TStructure &structure)
217 {
218 return tryGetRenamedImpl(structure, &mStructures);
219 }
220
tryGetRenamed(const TType & type)221 const TType *tryGetRenamed(const TType &type)
222 {
223 return tryGetRenamedImpl(type, static_cast<Remapping<TType> *>(nullptr));
224 }
225
tryGetRenamed(const TVariable & variable)226 const TVariable *tryGetRenamed(const TVariable &variable)
227 {
228 return tryGetRenamedImpl(variable, &mVariables);
229 }
230
231 template <typename T>
getRenamedOrOriginal(const T & object)232 const T &getRenamedOrOriginal(const T &object)
233 {
234 const T *renamed = tryGetRenamed(object);
235 if (renamed)
236 {
237 return *renamed;
238 }
239 return object;
240 }
241
242 template <typename T>
needsRenamingImpl(const T & object) const243 bool needsRenamingImpl(const T &object) const
244 {
245 const SymbolType symbolType = object.symbolType();
246 switch (symbolType)
247 {
248 case SymbolType::BuiltIn:
249 case SymbolType::AngleInternal:
250 case SymbolType::Empty:
251 return false;
252
253 case SymbolType::UserDefined:
254 break;
255 }
256
257 const ImmutableString name = Name(object).rawName();
258 if (mKeywords.find(name) != mKeywords.end())
259 {
260 return true;
261 }
262
263 if (name.beginsWith(kAngleInternalPrefix))
264 {
265 return true;
266 }
267
268 return false;
269 }
270
needsRenaming(const TField & field,bool recursive) const271 bool needsRenaming(const TField &field, bool recursive) const
272 {
273 return needsRenamingImpl(field) || (recursive && needsRenaming(*field.type(), true));
274 }
275
needsRenaming(const TFieldList & fieldList,bool recursive) const276 bool needsRenaming(const TFieldList &fieldList, bool recursive) const
277 {
278 ASSERT(recursive);
279 for (const TField *field : fieldList)
280 {
281 if (needsRenaming(*field, true))
282 {
283 return true;
284 }
285 }
286 return false;
287 }
288
needsRenaming(const TFunction & function,bool recursive) const289 bool needsRenaming(const TFunction &function, bool recursive) const
290 {
291 if (needsRenamingImpl(function))
292 {
293 return true;
294 }
295
296 if (!recursive)
297 {
298 return false;
299 }
300
301 const size_t paramCount = function.getParamCount();
302 for (size_t i = 0; i < paramCount; ++i)
303 {
304 const TVariable ¶m = *function.getParam(i);
305 if (needsRenaming(param, true))
306 {
307 return true;
308 }
309 }
310
311 return false;
312 }
313
needsRenaming(const TInterfaceBlock & interfaceBlock,bool recursive) const314 bool needsRenaming(const TInterfaceBlock &interfaceBlock, bool recursive) const
315 {
316 return needsRenamingImpl(interfaceBlock) ||
317 (recursive && needsRenaming(interfaceBlock.fields(), true));
318 }
319
needsRenaming(const TStructure & structure,bool recursive) const320 bool needsRenaming(const TStructure &structure, bool recursive) const
321 {
322 return needsRenamingImpl(structure) ||
323 (recursive && needsRenaming(structure.fields(), true));
324 }
325
needsRenaming(const TType & type,bool recursive) const326 bool needsRenaming(const TType &type, bool recursive) const
327 {
328 if (const TStructure *structure = type.getStruct())
329 {
330 return needsRenaming(*structure, recursive);
331 }
332 else if (const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock())
333 {
334 return needsRenaming(*interfaceBlock, recursive);
335 }
336 else
337 {
338 return false;
339 }
340 }
341
needsRenaming(const TVariable & variable,bool recursive) const342 bool needsRenaming(const TVariable &variable, bool recursive) const
343 {
344 return needsRenamingImpl(variable) ||
345 (recursive && needsRenaming(variable.getType(), true));
346 }
347
348 public:
Rewriter(TCompiler & compiler,IdGen & idGen,const std::set<ImmutableString> & keywords)349 Rewriter(TCompiler &compiler, IdGen &idGen, const std::set<ImmutableString> &keywords)
350 : TIntermRebuild(compiler, false, true), mKeywords(keywords), mIdGen(idGen)
351 {}
352
visitSymbolPost(TIntermSymbol & symbolNode)353 PostResult visitSymbolPost(TIntermSymbol &symbolNode) override
354 {
355 const TVariable &var = symbolNode.variable();
356 if (needsRenaming(var, true))
357 {
358 const TVariable &rVar = getRenamedOrOriginal(var);
359 return *new TIntermSymbol(&rVar);
360 }
361 return symbolNode;
362 }
363
visitFunctionPrototype(TIntermFunctionPrototype & funcProtoNode)364 PostResult visitFunctionPrototype(TIntermFunctionPrototype &funcProtoNode)
365 {
366 const TFunction &func = *funcProtoNode.getFunction();
367 if (needsRenaming(func, true))
368 {
369 const TFunction &rFunc = getRenamedOrOriginal(func);
370 return *new TIntermFunctionPrototype(&rFunc);
371 }
372 return funcProtoNode;
373 }
374
visitDeclarationPost(TIntermDeclaration & declNode)375 PostResult visitDeclarationPost(TIntermDeclaration &declNode) override
376 {
377 Declaration decl = ViewDeclaration(declNode);
378 const TVariable &var = decl.symbol.variable();
379 if (needsRenaming(var, true))
380 {
381 const TVariable &rVar = getRenamedOrOriginal(var);
382 return *new TIntermDeclaration(&rVar, decl.initExpr);
383 }
384 return declNode;
385 }
386
visitFunctionDefinitionPost(TIntermFunctionDefinition & funcDefNode)387 PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &funcDefNode) override
388 {
389 TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
390 const TFunction &func = *funcProtoNode.getFunction();
391 if (needsRenaming(func, true))
392 {
393 const TFunction &rFunc = getRenamedOrOriginal(func);
394 auto *rFuncProtoNode = new TIntermFunctionPrototype(&rFunc);
395 return *new TIntermFunctionDefinition(rFuncProtoNode, funcDefNode.getBody());
396 }
397 return funcDefNode;
398 }
399
visitAggregatePost(TIntermAggregate & aggregateNode)400 PostResult visitAggregatePost(TIntermAggregate &aggregateNode) override
401 {
402 if (aggregateNode.isConstructor())
403 {
404 const TType &type = aggregateNode.getType();
405 if (needsRenaming(type, true))
406 {
407 const TType &rType = getRenamedOrOriginal(type);
408 return TIntermAggregate::CreateConstructor(rType, aggregateNode.getSequence());
409 }
410 }
411 else
412 {
413 const TFunction &func = *aggregateNode.getFunction();
414 if (needsRenaming(func, true))
415 {
416 const TFunction &rFunc = getRenamedOrOriginal(func);
417 switch (aggregateNode.getOp())
418 {
419 case TOperator::EOpCallFunctionInAST:
420 return TIntermAggregate::CreateFunctionCall(rFunc,
421 aggregateNode.getSequence());
422
423 case TOperator::EOpCallInternalRawFunction:
424 return TIntermAggregate::CreateRawFunctionCall(rFunc,
425 aggregateNode.getSequence());
426
427 default:
428 return TIntermAggregate::CreateBuiltInFunctionCall(
429 rFunc, aggregateNode.getSequence());
430 }
431 }
432 }
433 return aggregateNode;
434 }
435
predefineName(const ImmutableString name,std::string prePopulatedName)436 void predefineName(const ImmutableString name, std::string prePopulatedName)
437 {
438 mPredefinedNames[name] = prePopulatedName;
439 }
440 };
441
442 } // anonymous namespace
443
444 ////////////////////////////////////////////////////////////////////////////////
445
RewriteKeywords(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const std::set<ImmutableString> & keywords)446 bool sh::RewriteKeywords(TCompiler &compiler,
447 TIntermBlock &root,
448 IdGen &idGen,
449 const std::set<ImmutableString> &keywords)
450 {
451 Rewriter rewriter(compiler, idGen, keywords);
452 const auto &inputAttrs = compiler.getAttributes();
453 for (const auto &var : inputAttrs)
454 {
455 rewriter.predefineName(ImmutableString(var.name), var.mappedName);
456 }
457 if (!rewriter.rebuildRoot(root))
458 {
459 return false;
460 }
461 return true;
462 }
463