1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cstring>
8 #include <numeric>
9 #include <unordered_map>
10 #include <unordered_set>
11 
12 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
13 
14 using namespace sh;
15 
16 ////////////////////////////////////////////////////////////////////////////////
17 
ViewDeclaration(TIntermDeclaration & declNode)18 Declaration sh::ViewDeclaration(TIntermDeclaration &declNode)
19 {
20     ASSERT(declNode.getChildCount() == 1);
21     TIntermNode *childNode = declNode.getChildNode(0);
22     ASSERT(childNode);
23     TIntermSymbol *symbolNode;
24     if ((symbolNode = childNode->getAsSymbolNode()))
25     {
26         return {*symbolNode, nullptr};
27     }
28     else
29     {
30         TIntermBinary *initNode = childNode->getAsBinaryNode();
31         ASSERT(initNode);
32         ASSERT(initNode->getOp() == TOperator::EOpInitialize);
33         symbolNode = initNode->getLeft()->getAsSymbolNode();
34         ASSERT(symbolNode);
35         return {*symbolNode, initNode->getRight()};
36     }
37 }
38 
CreateStructTypeVariable(TSymbolTable & symbolTable,const TStructure & structure)39 const TVariable &sh::CreateStructTypeVariable(TSymbolTable &symbolTable,
40                                               const TStructure &structure)
41 {
42     auto *type = new TType(&structure, true);
43     auto *var  = new TVariable(&symbolTable, ImmutableString(""), type, SymbolType::Empty);
44     return *var;
45 }
46 
CreateInstanceVariable(TSymbolTable & symbolTable,const TStructure & structure,const Name & name,TQualifier qualifier,const TSpan<const unsigned int> * arraySizes)47 const TVariable &sh::CreateInstanceVariable(TSymbolTable &symbolTable,
48                                             const TStructure &structure,
49                                             const Name &name,
50                                             TQualifier qualifier,
51                                             const TSpan<const unsigned int> *arraySizes)
52 {
53     auto *type = new TType(&structure, false);
54     type->setQualifier(qualifier);
55     if (arraySizes)
56     {
57         type->makeArrays(*arraySizes);
58     }
59     auto *var = new TVariable(&symbolTable, name.rawName(), type, name.symbolType());
60     return *var;
61 }
62 
AcquireFunctionExtras(TFunction & dest,const TFunction & src)63 static void AcquireFunctionExtras(TFunction &dest, const TFunction &src)
64 {
65     if (src.isDefined())
66     {
67         dest.setDefined();
68     }
69 
70     if (src.hasPrototypeDeclaration())
71     {
72         dest.setHasPrototypeDeclaration();
73     }
74 }
75 
// Returns a newly allocated sequence containing `node` followed by every element of `seq`, in
// order. Only the sequence is new: the node pointers are shared with `seq`, not deep-copied.
TIntermSequence &sh::CloneSequenceAndPrepend(const TIntermSequence &seq, TIntermNode &node)
{
    TIntermSequence &result = *new TIntermSequence();
    result.push_back(&node);
    result.insert(result.end(), seq.begin(), seq.end());
    return result;
}
88 
AddParametersFrom(TFunction & dest,const TFunction & src)89 void sh::AddParametersFrom(TFunction &dest, const TFunction &src)
90 {
91     const size_t paramCount = src.getParamCount();
92     for (size_t i = 0; i < paramCount; ++i)
93     {
94         const TVariable *var = src.getParam(i);
95         dest.addParameter(var);
96     }
97 }
98 
CloneFunction(TSymbolTable & symbolTable,IdGen & idGen,const TFunction & oldFunc)99 const TFunction &sh::CloneFunction(TSymbolTable &symbolTable,
100                                    IdGen &idGen,
101                                    const TFunction &oldFunc)
102 {
103     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
104 
105     Name newName = idGen.createNewName(Name(oldFunc));
106 
107     auto &newFunc = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
108                                    &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
109 
110     AcquireFunctionExtras(newFunc, oldFunc);
111     AddParametersFrom(newFunc, oldFunc);
112 
113     return newFunc;
114 }
115 
CloneFunctionAndPrependParam(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TVariable & newParam)116 const TFunction &sh::CloneFunctionAndPrependParam(TSymbolTable &symbolTable,
117                                                   IdGen *idGen,
118                                                   const TFunction &oldFunc,
119                                                   const TVariable &newParam)
120 {
121     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
122            oldFunc.symbolType() == SymbolType::AngleInternal);
123 
124     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
125 
126     auto &newFunc = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
127                                    &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
128 
129     AcquireFunctionExtras(newFunc, oldFunc);
130     newFunc.addParameter(&newParam);
131     AddParametersFrom(newFunc, oldFunc);
132 
133     return newFunc;
134 }
135 
CloneFunctionAndAppendParams(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const std::vector<const TVariable * > & newParams)136 const TFunction &sh::CloneFunctionAndAppendParams(TSymbolTable &symbolTable,
137                                                   IdGen *idGen,
138                                                   const TFunction &oldFunc,
139                                                   const std::vector<const TVariable *> &newParams)
140 {
141     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined ||
142            oldFunc.symbolType() == SymbolType::AngleInternal);
143 
144     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
145 
146     auto &newFunc = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
147                                    &oldFunc.getReturnType(), oldFunc.isKnownToNotHaveSideEffects());
148 
149     AcquireFunctionExtras(newFunc, oldFunc);
150     AddParametersFrom(newFunc, oldFunc);
151     for (auto *param : newParams)
152     {
153         newFunc.addParameter(param);
154     }
155 
156     return newFunc;
157 }
158 
CloneFunctionAndChangeReturnType(TSymbolTable & symbolTable,IdGen * idGen,const TFunction & oldFunc,const TStructure & newReturn)159 const TFunction &sh::CloneFunctionAndChangeReturnType(TSymbolTable &symbolTable,
160                                                       IdGen *idGen,
161                                                       const TFunction &oldFunc,
162                                                       const TStructure &newReturn)
163 {
164     ASSERT(oldFunc.symbolType() == SymbolType::UserDefined);
165 
166     Name newName = idGen ? idGen->createNewName(Name(oldFunc)) : Name(oldFunc);
167 
168     auto *newReturnType = new TType(&newReturn, true);
169     auto &newFunc       = *new TFunction(&symbolTable, newName.rawName(), newName.symbolType(),
170                                    newReturnType, oldFunc.isKnownToNotHaveSideEffects());
171 
172     AcquireFunctionExtras(newFunc, oldFunc);
173     AddParametersFrom(newFunc, oldFunc);
174 
175     return newFunc;
176 }
177 
GetArg(const TIntermAggregate & call,size_t index)178 TIntermTyped &sh::GetArg(const TIntermAggregate &call, size_t index)
179 {
180     ASSERT(index < call.getChildCount());
181     auto *arg = call.getChildNode(index);
182     ASSERT(arg);
183     auto *targ = arg->getAsTyped();
184     ASSERT(targ);
185     return *targ;
186 }
187 
SetArg(TIntermAggregate & call,size_t index,TIntermTyped & arg)188 void sh::SetArg(TIntermAggregate &call, size_t index, TIntermTyped &arg)
189 {
190     ASSERT(index < call.getChildCount());
191     (*call.getSequence())[index] = &arg;
192 }
193 
GetFieldIndex(const TStructure & structure,const ImmutableString & fieldName)194 int sh::GetFieldIndex(const TStructure &structure, const ImmutableString &fieldName)
195 {
196     const TFieldList &fieldList = structure.fields();
197 
198     int i = 0;
199     for (TField *field : fieldList)
200     {
201         if (field->name() == fieldName)
202         {
203             return i;
204         }
205         ++i;
206     }
207 
208     return -1;
209 }
210 
AccessField(const TVariable & structInstanceVar,const ImmutableString & fieldName)211 TIntermBinary &sh::AccessField(const TVariable &structInstanceVar, const ImmutableString &fieldName)
212 {
213     return AccessField(*new TIntermSymbol(&structInstanceVar), fieldName);
214 }
215 
AccessField(TIntermTyped & object,const ImmutableString & fieldName)216 TIntermBinary &sh::AccessField(TIntermTyped &object, const ImmutableString &fieldName)
217 {
218     const TStructure *structure = object.getType().getStruct();
219     ASSERT(structure);
220 
221     const int index = GetFieldIndex(*structure, fieldName);
222     ASSERT(index >= 0);
223     return AccessFieldByIndex(object, index);
224 }
225 
AccessFieldByIndex(TIntermTyped & object,int index)226 TIntermBinary &sh::AccessFieldByIndex(TIntermTyped &object, int index)
227 {
228 #if defined(ANGLE_ENABLE_ASSERTS)
229     const TType &type = object.getType();
230     ASSERT(!type.isArray());
231     const TStructure *structure = type.getStruct();
232     ASSERT(structure);
233     ASSERT(0 <= index);
234     ASSERT(static_cast<size_t>(index) < structure->fields().size());
235 #endif
236 
237     return *new TIntermBinary(
238         TOperator::EOpIndexDirectStruct, &object,
239         new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
240 }
241 
AccessIndex(TIntermTyped & indexableNode,int index)242 TIntermBinary &sh::AccessIndex(TIntermTyped &indexableNode, int index)
243 {
244 #if defined(ANGLE_ENABLE_ASSERTS)
245     const TType &type = indexableNode.getType();
246     ASSERT(type.isArray() || type.isVector() || type.isMatrix());
247 #endif
248 
249     auto *accessNode = new TIntermBinary(
250         TOperator::EOpIndexDirect, &indexableNode,
251         new TIntermConstantUnion(new TConstantUnion(index), *new TType(TBasicType::EbtInt)));
252     return *accessNode;
253 }
254 
AccessIndex(TIntermTyped & node,const int * index)255 TIntermTyped &sh::AccessIndex(TIntermTyped &node, const int *index)
256 {
257     if (index)
258     {
259         return AccessIndex(node, *index);
260     }
261     return node;
262 }
263 
SubVector(TIntermTyped & vectorNode,int begin,int end)264 TIntermTyped &sh::SubVector(TIntermTyped &vectorNode, int begin, int end)
265 {
266     ASSERT(vectorNode.getType().isVector());
267     ASSERT(0 <= begin);
268     ASSERT(end <= 4);
269     ASSERT(begin <= end);
270     if (begin == 0 && end == vectorNode.getType().getNominalSize())
271     {
272         return vectorNode;
273     }
274     TVector<int> offsets(static_cast<size_t>(end - begin));
275     std::iota(offsets.begin(), offsets.end(), begin);
276     auto *swizzle = new TIntermSwizzle(&vectorNode, offsets);
277     return *swizzle;
278 }
279 
IsScalarBasicType(const TType & type)280 bool sh::IsScalarBasicType(const TType &type)
281 {
282     if (!type.isScalar())
283     {
284         return false;
285     }
286     return HasScalarBasicType(type);
287 }
288 
IsVectorBasicType(const TType & type)289 bool sh::IsVectorBasicType(const TType &type)
290 {
291     if (!type.isVector())
292     {
293         return false;
294     }
295     return HasScalarBasicType(type);
296 }
297 
HasScalarBasicType(TBasicType type)298 bool sh::HasScalarBasicType(TBasicType type)
299 {
300     switch (type)
301     {
302         case TBasicType::EbtFloat:
303         case TBasicType::EbtDouble:
304         case TBasicType::EbtInt:
305         case TBasicType::EbtUInt:
306         case TBasicType::EbtBool:
307             return true;
308 
309         default:
310             return false;
311     }
312 }
313 
HasScalarBasicType(const TType & type)314 bool sh::HasScalarBasicType(const TType &type)
315 {
316     return HasScalarBasicType(type.getBasicType());
317 }
318 
InitType(TType & type)319 static void InitType(TType &type)
320 {
321     if (type.isArray())
322     {
323         auto sizes = type.getArraySizes();
324         type.toArrayBaseType();
325         type.makeArrays(sizes);
326     }
327 }
328 
// Returns a newly allocated copy of `type`. If the copy is an array, its array sizes are
// re-applied via InitType.
TType &sh::CloneType(const TType &type)
{
    TType *copy = new TType(type);
    InitType(*copy);
    return *copy;
}
335 
// Returns a newly allocated copy of `type` with its array dimensions removed via
// toArrayBaseType() (presumably all of them — confirm), e.g. the element type of a possibly
// nested array.
TType &sh::InnermostType(const TType &type)
{
    TType *elementType = new TType(type);
    elementType->toArrayBaseType();
    // InitType only acts on array types; kept for parity with the other type builders here.
    InitType(*elementType);
    return *elementType;
}
343 
// For a matrix type with a scalar basic type, returns a newly allocated vector type with
// getRows() components. This is the type of a single column of the matrix. Basic type,
// precision, qualifier and array sizes are carried over.
TType &sh::DropColumns(const TType &matrixType)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(HasScalarBasicType(matrixType));

    // Explicit null mangled name, passed as a typed pointer.
    const char *noMangledName = nullptr;
    TType *columnType =
        new TType(matrixType.getBasicType(), matrixType.getPrecision(), matrixType.getQualifier(),
                  matrixType.getRows(), 1, matrixType.getArraySizes(), noMangledName);
    InitType(*columnType);
    return *columnType;
}
356 
// Returns a newly allocated type equal to arrayType minus one array dimension. The LAST entry
// of getArraySizes() is the one dropped, treated here as the outermost dimension. Basic type,
// precision, qualifier and primary/secondary sizes are carried over.
// NOTE(review): only those fields are passed to the new TType. Struct / interface-block
// information on arrayType is not propagated, so confirm that callers only use this for arrays
// of basic types.
TType &sh::DropOuterDimension(const TType &arrayType)
{
    ASSERT(arrayType.isArray());

    const auto &allSizes      = arrayType.getArraySizes();
    const size_t keptDims     = allSizes.size() - 1;
    const char *noMangledName = nullptr;

    TType *innerType = new TType(arrayType.getBasicType(), arrayType.getPrecision(),
                                 arrayType.getQualifier(), arrayType.getNominalSize(),
                                 arrayType.getSecondarySize(), allSizes.subspan(0, keptDims),
                                 noMangledName);
    InitType(*innerType);
    return *innerType;
}
370 
// Returns a newly allocated type with the same basic type, precision, qualifier and array sizes
// as `type`, but with the given primary and secondary sizes. Asserted ranges:
// primary in (1, 4], secondary in [1, 4]. The basic type must be scalar-capable.
static TType &SetTypeDimsImpl(const TType &type, int primary, int secondary)
{
    ASSERT(1 < primary && primary <= 4);
    ASSERT(1 <= secondary && secondary <= 4);
    ASSERT(HasScalarBasicType(type));

    const char *noMangledName = nullptr;
    TType *resized = new TType(type.getBasicType(), type.getPrecision(), type.getQualifier(),
                               primary, secondary, type.getArraySizes(), noMangledName);
    InitType(*resized);
    return *resized;
}
383 
// Returns a copy of a scalar or vector type resized to a vector of newDim components
// (newDim must be in (1, 4]).
TType &sh::SetVectorDim(const TType &type, int newDim)
{
    ASSERT(type.isRank0() || type.isVector());
    const int vectorSecondarySize = 1;
    return SetTypeDimsImpl(type, newDim, vectorSecondarySize);
}
389 
// Returns a copy of a matrix type with the same column count but newDim rows
// (newDim must be in (1, 4]). The column count is passed as the primary size and the row count
// as the secondary size.
TType &sh::SetMatrixRowDim(const TType &matrixType, int newDim)
{
    ASSERT(matrixType.isMatrix());
    ASSERT(1 < newDim && newDim <= 4);
    const int columnCount = matrixType.getCols();
    return SetTypeDimsImpl(matrixType, columnCount, newDim);
}
396 
HasMatrixField(const TStructure & structure)397 bool sh::HasMatrixField(const TStructure &structure)
398 {
399     for (const TField *field : structure.fields())
400     {
401         const TType &type = *field->type();
402         if (type.isMatrix())
403         {
404             return true;
405         }
406     }
407     return false;
408 }
409 
HasArrayField(const TStructure & structure)410 bool sh::HasArrayField(const TStructure &structure)
411 {
412     for (const TField *field : structure.fields())
413     {
414         const TType &type = *field->type();
415         if (type.isArray())
416         {
417             return true;
418         }
419     }
420     return false;
421 }
422 
CoerceSimple(TBasicType toType,TIntermTyped & fromNode)423 TIntermTyped &sh::CoerceSimple(TBasicType toType, TIntermTyped &fromNode)
424 {
425     const TType &fromType = fromNode.getType();
426 
427     ASSERT(HasScalarBasicType(toType));
428     ASSERT(HasScalarBasicType(fromType));
429     ASSERT(!fromType.isArray());
430 
431     if (toType != fromType.getBasicType())
432     {
433         return *TIntermAggregate::CreateConstructor(
434             *new TType(toType, fromType.getNominalSize(), fromType.getSecondarySize()),
435             new TIntermSequence{&fromNode});
436     }
437     return fromNode;
438 }
439 
CoerceSimple(const TType & toType,TIntermTyped & fromNode)440 TIntermTyped &sh::CoerceSimple(const TType &toType, TIntermTyped &fromNode)
441 {
442     const TType &fromType = fromNode.getType();
443 
444     ASSERT(HasScalarBasicType(toType));
445     ASSERT(HasScalarBasicType(fromType));
446     ASSERT(toType.getNominalSize() == fromType.getNominalSize());
447     ASSERT(toType.getSecondarySize() == fromType.getSecondarySize());
448     ASSERT(!toType.isArray());
449     ASSERT(!fromType.isArray());
450 
451     if (toType.getBasicType() != fromType.getBasicType())
452     {
453         return *TIntermAggregate::CreateConstructor(toType, new TIntermSequence{&fromNode});
454     }
455     return fromNode;
456 }
457 
AsType(SymbolEnv & symbolEnv,const TType & toType,TIntermTyped & fromNode)458 TIntermTyped &sh::AsType(SymbolEnv &symbolEnv, const TType &toType, TIntermTyped &fromNode)
459 {
460     const TType &fromType = fromNode.getType();
461 
462     ASSERT(HasScalarBasicType(toType));
463     ASSERT(HasScalarBasicType(fromType));
464     ASSERT(!toType.isArray());
465     ASSERT(!fromType.isArray());
466 
467     if (toType == fromType)
468     {
469         return fromNode;
470     }
471     TemplateArg targ(toType);
472     return symbolEnv.callFunctionOverload(Name("as_type", SymbolType::BuiltIn), toType,
473                                           *new TIntermSequence{&fromNode}, 1, &targ);
474 }
475