1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
#include <cctype>
#include <cstring>
#include <limits>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
13 
14 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
15 #include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
16 #include "compiler/translator/TranslatorMetalDirect/RewriteKeywords.h"
17 
18 using namespace sh;
19 
20 ////////////////////////////////////////////////////////////////////////////////
21 
22 namespace
23 {
24 
25 template <typename T>
26 using Remapping = std::unordered_map<const T *, const T *>;
27 
28 class Rewriter : public TIntermRebuild
29 {
30   private:
31     const std::set<ImmutableString> &mKeywords;
32     IdGen &mIdGen;
33     Remapping<TField> modifiedFields;
34     Remapping<TFieldList> mFieldLists;
35     Remapping<TFunction> mFunctions;
36     Remapping<TInterfaceBlock> mInterfaceBlocks;
37     Remapping<TStructure> mStructures;
38     Remapping<TVariable> mVariables;
39     std::map<ImmutableString, std::string> mPredefinedNames;
40     std::string mNewNameBuffer;
41 
42   private:
43     template <typename T>
maybeCreateNewName(T const & object)44     ImmutableString maybeCreateNewName(T const &object)
45     {
46         if (needsRenaming(object, false))
47         {
48             auto it = mPredefinedNames.find(Name(object).rawName());
49             if (it != mPredefinedNames.end())
50             {
51                 return ImmutableString(it->second);
52             }
53             return mIdGen.createNewName(Name(object)).rawName();
54         }
55         return Name(object).rawName();
56     }
57 
createRenamed(const TField & field)58     const TField *createRenamed(const TField &field)
59     {
60         auto *renamed =
61             new TField(const_cast<TType *>(&getRenamedOrOriginal(*field.type())),
62                        maybeCreateNewName(field), field.line(), SymbolType::AngleInternal);
63 
64         return renamed;
65     }
66 
createRenamed(const TFieldList & fieldList)67     const TFieldList *createRenamed(const TFieldList &fieldList)
68     {
69         auto *renamed = new TFieldList();
70         for (const TField *field : fieldList)
71         {
72             renamed->push_back(const_cast<TField *>(&getRenamedOrOriginal(*field)));
73         }
74         return renamed;
75     }
76 
createRenamed(const TFunction & function)77     const TFunction *createRenamed(const TFunction &function)
78     {
79         auto *renamed =
80             new TFunction(&mSymbolTable, maybeCreateNewName(function), SymbolType::AngleInternal,
81                           &getRenamedOrOriginal(function.getReturnType()),
82                           function.isKnownToNotHaveSideEffects());
83 
84         const size_t paramCount = function.getParamCount();
85         for (size_t i = 0; i < paramCount; ++i)
86         {
87             const TVariable &param = *function.getParam(i);
88             renamed->addParameter(&getRenamedOrOriginal(param));
89         }
90 
91         if (function.isDefined())
92         {
93             renamed->setDefined();
94         }
95 
96         if (function.hasPrototypeDeclaration())
97         {
98             renamed->setHasPrototypeDeclaration();
99         }
100 
101         return renamed;
102     }
103 
createRenamed(const TInterfaceBlock & interfaceBlock)104     const TInterfaceBlock *createRenamed(const TInterfaceBlock &interfaceBlock)
105     {
106         TLayoutQualifier layoutQualifier = TLayoutQualifier::Create();
107         layoutQualifier.blockStorage     = interfaceBlock.blockStorage();
108         layoutQualifier.binding          = interfaceBlock.blockBinding();
109 
110         auto *renamed =
111             new TInterfaceBlock(&mSymbolTable, maybeCreateNewName(interfaceBlock),
112                                 &getRenamedOrOriginal(interfaceBlock.fields()), layoutQualifier,
113                                 SymbolType::AngleInternal, interfaceBlock.extensions());
114 
115         return renamed;
116     }
117 
createRenamed(const TStructure & structure)118     const TStructure *createRenamed(const TStructure &structure)
119     {
120         auto *renamed =
121             new TStructure(&mSymbolTable, maybeCreateNewName(structure),
122                            &getRenamedOrOriginal(structure.fields()), SymbolType::AngleInternal);
123 
124         renamed->setAtGlobalScope(structure.atGlobalScope());
125 
126         return renamed;
127     }
128 
createRenamed(const TType & type)129     const TType *createRenamed(const TType &type)
130     {
131         TType *renamed;
132 
133         if (const TStructure *structure = type.getStruct())
134         {
135             renamed = new TType(&getRenamedOrOriginal(*structure), type.isStructSpecifier());
136         }
137         else if (const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock())
138         {
139             renamed = new TType(&getRenamedOrOriginal(*interfaceBlock), type.getQualifier(),
140                                 type.getLayoutQualifier());
141         }
142         else
143         {
144             UNREACHABLE();  // Can't rename built-in types.
145             renamed = nullptr;
146         }
147 
148         if (type.isArray())
149         {
150             renamed->makeArrays(type.getArraySizes());
151         }
152         renamed->setPrecise(type.isPrecise());
153         renamed->setInvariant(type.isInvariant());
154         renamed->setMemoryQualifier(type.getMemoryQualifier());
155         renamed->setLayoutQualifier(type.getLayoutQualifier());
156 
157         return renamed;
158     }
159 
createRenamed(const TVariable & variable)160     const TVariable *createRenamed(const TVariable &variable)
161     {
162         auto *renamed = new TVariable(&mSymbolTable, maybeCreateNewName(variable),
163                                       &getRenamedOrOriginal(variable.getType()),
164                                       SymbolType::AngleInternal, variable.extensions());
165 
166         return renamed;
167     }
168 
169     template <typename T>
tryGetRenamedImpl(const T & object,Remapping<T> * remapping)170     const T *tryGetRenamedImpl(const T &object, Remapping<T> *remapping)
171     {
172         if (!needsRenaming(object, true))
173         {
174             return nullptr;
175         }
176 
177         if (remapping)
178         {
179             auto it = remapping->find(&object);
180             if (it != remapping->end())
181             {
182                 return it->second;
183             }
184         }
185 
186         const T *renamedObject = createRenamed(object);
187 
188         if (remapping)
189         {
190             (*remapping)[&object] = renamedObject;
191         }
192 
193         return renamedObject;
194     }
195 
tryGetRenamed(const TField & field)196     const TField *tryGetRenamed(const TField &field)
197     {
198         return tryGetRenamedImpl(field, &modifiedFields);
199     }
200 
tryGetRenamed(const TFieldList & fieldList)201     const TFieldList *tryGetRenamed(const TFieldList &fieldList)
202     {
203         return tryGetRenamedImpl(fieldList, &mFieldLists);
204     }
205 
tryGetRenamed(const TFunction & func)206     const TFunction *tryGetRenamed(const TFunction &func)
207     {
208         return tryGetRenamedImpl(func, &mFunctions);
209     }
210 
tryGetRenamed(const TInterfaceBlock & interfaceBlock)211     const TInterfaceBlock *tryGetRenamed(const TInterfaceBlock &interfaceBlock)
212     {
213         return tryGetRenamedImpl(interfaceBlock, &mInterfaceBlocks);
214     }
215 
tryGetRenamed(const TStructure & structure)216     const TStructure *tryGetRenamed(const TStructure &structure)
217     {
218         return tryGetRenamedImpl(structure, &mStructures);
219     }
220 
tryGetRenamed(const TType & type)221     const TType *tryGetRenamed(const TType &type)
222     {
223         return tryGetRenamedImpl(type, static_cast<Remapping<TType> *>(nullptr));
224     }
225 
tryGetRenamed(const TVariable & variable)226     const TVariable *tryGetRenamed(const TVariable &variable)
227     {
228         return tryGetRenamedImpl(variable, &mVariables);
229     }
230 
231     template <typename T>
getRenamedOrOriginal(const T & object)232     const T &getRenamedOrOriginal(const T &object)
233     {
234         const T *renamed = tryGetRenamed(object);
235         if (renamed)
236         {
237             return *renamed;
238         }
239         return object;
240     }
241 
242     template <typename T>
needsRenamingImpl(const T & object) const243     bool needsRenamingImpl(const T &object) const
244     {
245         const SymbolType symbolType = object.symbolType();
246         switch (symbolType)
247         {
248             case SymbolType::BuiltIn:
249             case SymbolType::AngleInternal:
250             case SymbolType::Empty:
251                 return false;
252 
253             case SymbolType::UserDefined:
254                 break;
255         }
256 
257         const ImmutableString name = Name(object).rawName();
258         if (mKeywords.find(name) != mKeywords.end())
259         {
260             return true;
261         }
262 
263         if (name.beginsWith(kAngleInternalPrefix))
264         {
265             return true;
266         }
267 
268         return false;
269     }
270 
needsRenaming(const TField & field,bool recursive) const271     bool needsRenaming(const TField &field, bool recursive) const
272     {
273         return needsRenamingImpl(field) || (recursive && needsRenaming(*field.type(), true));
274     }
275 
needsRenaming(const TFieldList & fieldList,bool recursive) const276     bool needsRenaming(const TFieldList &fieldList, bool recursive) const
277     {
278         ASSERT(recursive);
279         for (const TField *field : fieldList)
280         {
281             if (needsRenaming(*field, true))
282             {
283                 return true;
284             }
285         }
286         return false;
287     }
288 
needsRenaming(const TFunction & function,bool recursive) const289     bool needsRenaming(const TFunction &function, bool recursive) const
290     {
291         if (needsRenamingImpl(function))
292         {
293             return true;
294         }
295 
296         if (!recursive)
297         {
298             return false;
299         }
300 
301         const size_t paramCount = function.getParamCount();
302         for (size_t i = 0; i < paramCount; ++i)
303         {
304             const TVariable &param = *function.getParam(i);
305             if (needsRenaming(param, true))
306             {
307                 return true;
308             }
309         }
310 
311         return false;
312     }
313 
needsRenaming(const TInterfaceBlock & interfaceBlock,bool recursive) const314     bool needsRenaming(const TInterfaceBlock &interfaceBlock, bool recursive) const
315     {
316         return needsRenamingImpl(interfaceBlock) ||
317                (recursive && needsRenaming(interfaceBlock.fields(), true));
318     }
319 
needsRenaming(const TStructure & structure,bool recursive) const320     bool needsRenaming(const TStructure &structure, bool recursive) const
321     {
322         return needsRenamingImpl(structure) ||
323                (recursive && needsRenaming(structure.fields(), true));
324     }
325 
needsRenaming(const TType & type,bool recursive) const326     bool needsRenaming(const TType &type, bool recursive) const
327     {
328         if (const TStructure *structure = type.getStruct())
329         {
330             return needsRenaming(*structure, recursive);
331         }
332         else if (const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock())
333         {
334             return needsRenaming(*interfaceBlock, recursive);
335         }
336         else
337         {
338             return false;
339         }
340     }
341 
needsRenaming(const TVariable & variable,bool recursive) const342     bool needsRenaming(const TVariable &variable, bool recursive) const
343     {
344         return needsRenamingImpl(variable) ||
345                (recursive && needsRenaming(variable.getType(), true));
346     }
347 
348   public:
Rewriter(TCompiler & compiler,IdGen & idGen,const std::set<ImmutableString> & keywords)349     Rewriter(TCompiler &compiler, IdGen &idGen, const std::set<ImmutableString> &keywords)
350         : TIntermRebuild(compiler, false, true), mKeywords(keywords), mIdGen(idGen)
351     {}
352 
visitSymbolPost(TIntermSymbol & symbolNode)353     PostResult visitSymbolPost(TIntermSymbol &symbolNode) override
354     {
355         const TVariable &var = symbolNode.variable();
356         if (needsRenaming(var, true))
357         {
358             const TVariable &rVar = getRenamedOrOriginal(var);
359             return *new TIntermSymbol(&rVar);
360         }
361         return symbolNode;
362     }
363 
visitFunctionPrototype(TIntermFunctionPrototype & funcProtoNode)364     PostResult visitFunctionPrototype(TIntermFunctionPrototype &funcProtoNode)
365     {
366         const TFunction &func = *funcProtoNode.getFunction();
367         if (needsRenaming(func, true))
368         {
369             const TFunction &rFunc = getRenamedOrOriginal(func);
370             return *new TIntermFunctionPrototype(&rFunc);
371         }
372         return funcProtoNode;
373     }
374 
visitDeclarationPost(TIntermDeclaration & declNode)375     PostResult visitDeclarationPost(TIntermDeclaration &declNode) override
376     {
377         Declaration decl     = ViewDeclaration(declNode);
378         const TVariable &var = decl.symbol.variable();
379         if (needsRenaming(var, true))
380         {
381             const TVariable &rVar = getRenamedOrOriginal(var);
382             return *new TIntermDeclaration(&rVar, decl.initExpr);
383         }
384         return declNode;
385     }
386 
visitFunctionDefinitionPost(TIntermFunctionDefinition & funcDefNode)387     PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &funcDefNode) override
388     {
389         TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
390         const TFunction &func                   = *funcProtoNode.getFunction();
391         if (needsRenaming(func, true))
392         {
393             const TFunction &rFunc = getRenamedOrOriginal(func);
394             auto *rFuncProtoNode   = new TIntermFunctionPrototype(&rFunc);
395             return *new TIntermFunctionDefinition(rFuncProtoNode, funcDefNode.getBody());
396         }
397         return funcDefNode;
398     }
399 
visitAggregatePost(TIntermAggregate & aggregateNode)400     PostResult visitAggregatePost(TIntermAggregate &aggregateNode) override
401     {
402         if (aggregateNode.isConstructor())
403         {
404             const TType &type = aggregateNode.getType();
405             if (needsRenaming(type, true))
406             {
407                 const TType &rType = getRenamedOrOriginal(type);
408                 return TIntermAggregate::CreateConstructor(rType, aggregateNode.getSequence());
409             }
410         }
411         else
412         {
413             const TFunction &func = *aggregateNode.getFunction();
414             if (needsRenaming(func, true))
415             {
416                 const TFunction &rFunc = getRenamedOrOriginal(func);
417                 switch (aggregateNode.getOp())
418                 {
419                     case TOperator::EOpCallFunctionInAST:
420                         return TIntermAggregate::CreateFunctionCall(rFunc,
421                                                                     aggregateNode.getSequence());
422 
423                     case TOperator::EOpCallInternalRawFunction:
424                         return TIntermAggregate::CreateRawFunctionCall(rFunc,
425                                                                        aggregateNode.getSequence());
426 
427                     default:
428                         return TIntermAggregate::CreateBuiltInFunctionCall(
429                             rFunc, aggregateNode.getSequence());
430                 }
431             }
432         }
433         return aggregateNode;
434     }
435 
predefineName(const ImmutableString name,std::string prePopulatedName)436     void predefineName(const ImmutableString name, std::string prePopulatedName)
437     {
438         mPredefinedNames[name] = prePopulatedName;
439     }
440 };
441 
442 }  // anonymous namespace
443 
444 ////////////////////////////////////////////////////////////////////////////////
445 
RewriteKeywords(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const std::set<ImmutableString> & keywords)446 bool sh::RewriteKeywords(TCompiler &compiler,
447                          TIntermBlock &root,
448                          IdGen &idGen,
449                          const std::set<ImmutableString> &keywords)
450 {
451     Rewriter rewriter(compiler, idGen, keywords);
452     const auto &inputAttrs = compiler.getAttributes();
453     for (const auto &var : inputAttrs)
454     {
455         rewriter.predefineName(ImmutableString(var.name), var.mappedName);
456     }
457     if (!rewriter.rebuildRoot(root))
458     {
459         return false;
460     }
461     return true;
462 }
463