1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
#include <algorithm>
#include <limits>
#include <string>
9 
10 #include "compiler/translator/ImmutableStringBuilder.h"
11 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
12 #include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
13 #include "compiler/translator/TranslatorMetalDirect/SymbolEnv.h"
14 
15 using namespace sh;
16 
17 ////////////////////////////////////////////////////////////////////////////////
18 
19 constexpr AddressSpace kAddressSpaces[] = {
20     AddressSpace::Constant,
21     AddressSpace::Device,
22     AddressSpace::Thread,
23 };
24 
toString(AddressSpace space)25 char const *sh::toString(AddressSpace space)
26 {
27     switch (space)
28     {
29         case AddressSpace::Constant:
30             return "constant";
31         case AddressSpace::Device:
32             return "device";
33         case AddressSpace::Thread:
34             return "thread";
35     }
36 }
37 
38 ////////////////////////////////////////////////////////////////////////////////
39 
40 using NameToStruct = std::map<Name, const TStructure *>;
41 
42 class StructFinder : TIntermRebuild
43 {
44     NameToStruct nameToStruct;
45 
StructFinder(TCompiler & compiler)46     StructFinder(TCompiler &compiler) : TIntermRebuild(compiler, true, false) {}
47 
visitDeclarationPre(TIntermDeclaration & node)48     PreResult visitDeclarationPre(TIntermDeclaration &node) override
49     {
50         Declaration decl     = ViewDeclaration(node);
51         const TVariable &var = decl.symbol.variable();
52         const TType &type    = var.getType();
53 
54         if (var.symbolType() == SymbolType::Empty && type.isStructSpecifier())
55         {
56             const TStructure *s = type.getStruct();
57             ASSERT(s);
58             const Name name(*s);
59             const TStructure *&z = nameToStruct[name];
60             ASSERT(!z);
61             z = s;
62         }
63 
64         return node;
65     }
66 
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)67     PreResult visitFunctionDefinitionPre(TIntermFunctionDefinition &node) override
68     {
69         return {node, VisitBits::Neither};
70     }
71 
72   public:
FindStructs(TCompiler & compiler,TIntermBlock & root)73     static NameToStruct FindStructs(TCompiler &compiler, TIntermBlock &root)
74     {
75         StructFinder finder(compiler);
76         if (!finder.rebuildRoot(root))
77         {
78             UNREACHABLE();
79         }
80         return std::move(finder.nameToStruct);
81     }
82 };
83 
84 ////////////////////////////////////////////////////////////////////////////////
85 
TemplateArg(bool value)86 TemplateArg::TemplateArg(bool value) : mKind(Kind::Bool), mValue(value) {}
87 
TemplateArg(int value)88 TemplateArg::TemplateArg(int value) : mKind(Kind::Int), mValue(value) {}
89 
TemplateArg(unsigned value)90 TemplateArg::TemplateArg(unsigned value) : mKind(Kind::UInt), mValue(value) {}
91 
TemplateArg(const TType & value)92 TemplateArg::TemplateArg(const TType &value) : mKind(Kind::Type), mValue(value) {}
93 
operator ==(const TemplateArg & other) const94 bool TemplateArg::operator==(const TemplateArg &other) const
95 {
96     if (mKind != other.mKind)
97     {
98         return false;
99     }
100 
101     switch (mKind)
102     {
103         case Kind::Bool:
104             return mValue.b == other.mValue.b;
105         case Kind::Int:
106             return mValue.i == other.mValue.i;
107         case Kind::UInt:
108             return mValue.u == other.mValue.u;
109         case Kind::Type:
110             return *mValue.t == *other.mValue.t;
111     }
112 }
113 
operator <(const TemplateArg & other) const114 bool TemplateArg::operator<(const TemplateArg &other) const
115 {
116     if (mKind < other.mKind)
117     {
118         return true;
119     }
120 
121     if (mKind > other.mKind)
122     {
123         return false;
124     }
125 
126     switch (mKind)
127     {
128         case Kind::Bool:
129             return mValue.b < other.mValue.b;
130         case Kind::Int:
131             return mValue.i < other.mValue.i;
132         case Kind::UInt:
133             return mValue.u < other.mValue.u;
134         case Kind::Type:
135             return *mValue.t < *other.mValue.t;
136     }
137 }
138 
139 ////////////////////////////////////////////////////////////////////////////////
140 
operator ==(const TemplateName & other) const141 bool SymbolEnv::TemplateName::operator==(const TemplateName &other) const
142 {
143     return baseName == other.baseName && templateArgs == other.templateArgs;
144 }
145 
operator <(const TemplateName & other) const146 bool SymbolEnv::TemplateName::operator<(const TemplateName &other) const
147 {
148     if (baseName < other.baseName)
149     {
150         return true;
151     }
152     if (other.baseName < baseName)
153     {
154         return false;
155     }
156     return templateArgs < other.templateArgs;
157 }
158 
empty() const159 bool SymbolEnv::TemplateName::empty() const
160 {
161     return baseName.empty() && templateArgs.empty();
162 }
163 
clear()164 void SymbolEnv::TemplateName::clear()
165 {
166     baseName = Name();
167     templateArgs.clear();
168 }
169 
fullName(std::string & buffer) const170 Name SymbolEnv::TemplateName::fullName(std::string &buffer) const
171 {
172     ASSERT(buffer.empty());
173 
174     if (templateArgs.empty())
175     {
176         return baseName;
177     }
178 
179     static constexpr size_t n = std::max({
180         std::numeric_limits<unsigned>::digits10,  //
181         std::numeric_limits<int>::digits10,       //
182         5,                                        // max_length("true", "false")
183     });
184 
185     buffer.reserve(baseName.rawName().length() + (n + 2) * templateArgs.size() + 1);
186     buffer += baseName.rawName().data();
187 
188     if (!templateArgs.empty())
189     {
190         buffer += "<";
191 
192         bool first = true;
193         char argBuffer[n + 1];
194         for (const TemplateArg &arg : templateArgs)
195         {
196             if (first)
197             {
198                 first = false;
199             }
200             else
201             {
202                 buffer += ", ";
203             }
204 
205             const TemplateArg::Value value = arg.value();
206             const TemplateArg::Kind kind   = arg.kind();
207             switch (kind)
208             {
209                 case TemplateArg::Kind::Bool:
210                     if (value.b)
211                     {
212                         buffer += "true";
213                     }
214                     else
215                     {
216                         buffer += "false";
217                     }
218                     break;
219 
220                 case TemplateArg::Kind::Int:
221                     sprintf(argBuffer, "%i", value.i);
222                     buffer += argBuffer;
223                     break;
224 
225                 case TemplateArg::Kind::UInt:
226                     sprintf(argBuffer, "%u", value.u);
227                     buffer += argBuffer;
228                     break;
229 
230                 case TemplateArg::Kind::Type:
231                 {
232                     const TType &type = *value.t;
233                     if (const TStructure *s = type.getStruct())
234                     {
235                         buffer += s->name().data();
236                     }
237                     else if (HasScalarBasicType(type))
238                     {
239                         ASSERT(!type.isArray());  // TODO
240                         buffer += type.getBasicString();
241                         if (type.isVector())
242                         {
243                             sprintf(argBuffer, "%i", type.getNominalSize());
244                             buffer += argBuffer;
245                         }
246                         else if (type.isMatrix())
247                         {
248                             sprintf(argBuffer, "%i", type.getCols());
249                             buffer += argBuffer;
250                             buffer += "x";
251                             sprintf(argBuffer, "%i", type.getRows());
252                             buffer += argBuffer;
253                         }
254                     }
255                 }
256                 break;
257             }
258         }
259 
260         buffer += ">";
261     }
262 
263     const ImmutableString name(buffer);
264     buffer.clear();
265 
266     return Name(name, baseName.symbolType());
267 }
268 
assign(const Name & name,size_t argCount,const TemplateArg * args)269 void SymbolEnv::TemplateName::assign(const Name &name, size_t argCount, const TemplateArg *args)
270 {
271     baseName = name;
272     templateArgs.clear();
273     for (size_t i = 0; i < argCount; ++i)
274     {
275         templateArgs.push_back(args[i]);
276     }
277 }
278 
279 ////////////////////////////////////////////////////////////////////////////////
280 
SymbolEnv(TCompiler & compiler,TIntermBlock & root)281 SymbolEnv::SymbolEnv(TCompiler &compiler, TIntermBlock &root)
282     : mSymbolTable(compiler.getSymbolTable()),
283       mNameToStruct(StructFinder::FindStructs(compiler, root))
284 {}
285 
remap(const TStructure & s) const286 const TStructure &SymbolEnv::remap(const TStructure &s) const
287 {
288     const Name name(s);
289     auto iter = mNameToStruct.find(name);
290     if (iter == mNameToStruct.end())
291     {
292         return s;
293     }
294     const TStructure &z = *iter->second;
295     return z;
296 }
297 
remap(const TStructure * s) const298 const TStructure *SymbolEnv::remap(const TStructure *s) const
299 {
300     if (s)
301     {
302         return &remap(*s);
303     }
304     return nullptr;
305 }
306 
getFunctionOverloadImpl()307 const TFunction &SymbolEnv::getFunctionOverloadImpl()
308 {
309     ASSERT(!mReusableSigBuffer.empty());
310 
311     SigToFunc &sigToFunc = mOverloads[mReusableTemplateNameBuffer];
312     TFunction *&func     = sigToFunc[mReusableSigBuffer];
313 
314     if (!func)
315     {
316         const TType &returnType = mReusableSigBuffer.back();
317         mReusableSigBuffer.pop_back();
318 
319         const Name name = mReusableTemplateNameBuffer.fullName(mReusableStringBuffer);
320 
321         func = new TFunction(&mSymbolTable, name.rawName(), name.symbolType(), &returnType, false);
322         for (const TType &paramType : mReusableSigBuffer)
323         {
324             func->addParameter(
325                 new TVariable(&mSymbolTable, kEmptyImmutableString, &paramType, SymbolType::Empty));
326         }
327     }
328 
329     mReusableSigBuffer.clear();
330     mReusableTemplateNameBuffer.clear();
331 
332     return *func;
333 }
334 
getFunctionOverload(const Name & name,const TType & returnType,size_t paramCount,const TType ** paramTypes,size_t templateArgCount,const TemplateArg * templateArgs)335 const TFunction &SymbolEnv::getFunctionOverload(const Name &name,
336                                                 const TType &returnType,
337                                                 size_t paramCount,
338                                                 const TType **paramTypes,
339                                                 size_t templateArgCount,
340                                                 const TemplateArg *templateArgs)
341 {
342     ASSERT(mReusableSigBuffer.empty());
343     ASSERT(mReusableTemplateNameBuffer.empty());
344 
345     for (size_t i = 0; i < paramCount; ++i)
346     {
347         mReusableSigBuffer.push_back(*paramTypes[i]);
348     }
349     mReusableSigBuffer.push_back(returnType);
350     mReusableTemplateNameBuffer.assign(name, templateArgCount, templateArgs);
351     return getFunctionOverloadImpl();
352 }
353 
callFunctionOverload(const Name & name,const TType & returnType,TIntermSequence & args,size_t templateArgCount,const TemplateArg * templateArgs)354 TIntermAggregate &SymbolEnv::callFunctionOverload(const Name &name,
355                                                   const TType &returnType,
356                                                   TIntermSequence &args,
357                                                   size_t templateArgCount,
358                                                   const TemplateArg *templateArgs)
359 {
360     ASSERT(mReusableSigBuffer.empty());
361     ASSERT(mReusableTemplateNameBuffer.empty());
362 
363     for (TIntermNode *arg : args)
364     {
365         TIntermTyped *targ = arg->getAsTyped();
366         ASSERT(targ);
367         mReusableSigBuffer.push_back(targ->getType());
368     }
369     mReusableSigBuffer.push_back(returnType);
370     mReusableTemplateNameBuffer.assign(name, templateArgCount, templateArgs);
371     const TFunction &func = getFunctionOverloadImpl();
372     return *TIntermAggregate::CreateRawFunctionCall(func, &args);
373 }
374 
newStructure(const Name & name,TFieldList & fields)375 const TStructure &SymbolEnv::newStructure(const Name &name, TFieldList &fields)
376 {
377     ASSERT(name.symbolType() == SymbolType::AngleInternal);
378 
379     TStructure *&s = mAngleStructs[name.rawName()];
380     ASSERT(!s);
381     s = new TStructure(&mSymbolTable, name.rawName(), &fields, name.symbolType());
382     return *s;
383 }
384 
getTextureEnv(TBasicType samplerType)385 const TStructure &SymbolEnv::getTextureEnv(TBasicType samplerType)
386 {
387     ASSERT(IsSampler(samplerType));
388     const TStructure *&env = mTextureEnvs[samplerType];
389     if (env == nullptr)
390     {
391         auto *textureType = new TType(samplerType);
392         auto *texture     = new TField(textureType, ImmutableString("texture"), kNoSourceLoc,
393                                    SymbolType::UserDefined);
394         markAsPointer(*texture, AddressSpace::Thread);
395 
396         auto *sampler =
397             new TField(new TType(&getSamplerStruct(), false), ImmutableString("sampler"),
398                        kNoSourceLoc, SymbolType::UserDefined);
399         markAsPointer(*sampler, AddressSpace::Thread);
400 
401         std::string envName;
402         envName += "TextureEnv<";
403         envName += GetTextureTypeName(samplerType).rawName().data();
404         envName += ">";
405 
406         env = &newStructure(Name(envName, SymbolType::AngleInternal),
407                             *new TFieldList{texture, sampler});
408     }
409     return *env;
410 }
411 
getSamplerStruct()412 const TStructure &SymbolEnv::getSamplerStruct()
413 {
414     if (!mSampler)
415     {
416         mSampler = new TStructure(&mSymbolTable, ImmutableString("metal::sampler"),
417                                   new TFieldList(), SymbolType::UserDefined);
418     }
419     return *mSampler;
420 }
421 
markSpace(VarField x,AddressSpace space,std::unordered_map<VarField,AddressSpace> & map)422 void SymbolEnv::markSpace(VarField x,
423                           AddressSpace space,
424                           std::unordered_map<VarField, AddressSpace> &map)
425 {
426     // It is in principle permissible to have references to pointers or multiple pointers, but this
427     // is not required for now and would require code changes to get right.
428     ASSERT(!isPointer(x));
429     ASSERT(!isReference(x));
430 
431     map[x] = space;
432 }
433 
removeSpace(VarField x,std::unordered_map<VarField,AddressSpace> & map)434 void SymbolEnv::removeSpace(VarField x, std::unordered_map<VarField, AddressSpace> &map)
435 {
436     // It is in principle permissible to have references to pointers or multiple pointers, but this
437     // is not required for now and would require code changes to get right.
438     map.erase(x);
439 }
440 
isSpace(VarField x,const std::unordered_map<VarField,AddressSpace> & map) const441 const AddressSpace *SymbolEnv::isSpace(VarField x,
442                                        const std::unordered_map<VarField, AddressSpace> &map) const
443 {
444     const auto iter = map.find(x);
445     if (iter == map.end())
446     {
447         return nullptr;
448     }
449     const AddressSpace space = iter->second;
450     const auto index         = static_cast<std::underlying_type_t<AddressSpace>>(space);
451     return &kAddressSpaces[index];
452 }
453 
markAsPointer(VarField x,AddressSpace space)454 void SymbolEnv::markAsPointer(VarField x, AddressSpace space)
455 {
456     return markSpace(x, space, mPointers);
457 }
458 
removePointer(VarField x)459 void SymbolEnv::removePointer(VarField x)
460 {
461     return removeSpace(x, mPointers);
462 }
463 
markAsReference(VarField x,AddressSpace space)464 void SymbolEnv::markAsReference(VarField x, AddressSpace space)
465 {
466     return markSpace(x, space, mReferences);
467 }
468 
isPointer(VarField x) const469 const AddressSpace *SymbolEnv::isPointer(VarField x) const
470 {
471     return isSpace(x, mPointers);
472 }
473 
isReference(VarField x) const474 const AddressSpace *SymbolEnv::isReference(VarField x) const
475 {
476     return isSpace(x, mReferences);
477 }
478 
markAsPacked(const TField & field)479 void SymbolEnv::markAsPacked(const TField &field)
480 {
481     mPackedFields.insert(&field);
482 }
483 
isPacked(const TField & field) const484 bool SymbolEnv::isPacked(const TField &field) const
485 {
486     return mPackedFields.find(&field) != mPackedFields.end();
487 }
488 
markAsUBO(VarField x)489 void SymbolEnv::markAsUBO(VarField x)
490 {
491     mUboFields.insert(x);
492 }
493 
isUBO(VarField x) const494 bool SymbolEnv::isUBO(VarField x) const
495 {
496     return mUboFields.find(x) != mUboFields.end();
497 }
498 
GetTextureBasicType(TBasicType basicType)499 static TBasicType GetTextureBasicType(TBasicType basicType)
500 {
501     ASSERT(IsSampler(basicType));
502 
503     switch (basicType)
504     {
505         case EbtSampler2D:
506         case EbtSampler3D:
507         case EbtSamplerCube:
508         case EbtSampler2DArray:
509         case EbtSamplerExternalOES:
510         case EbtSamplerExternal2DY2YEXT:
511         case EbtSampler2DRect:
512         case EbtSampler2DMS:
513         case EbtSampler2DMSArray:
514         case EbtSamplerVideoWEBGL:
515         case EbtSampler2DShadow:
516         case EbtSamplerCubeShadow:
517         case EbtSampler2DArrayShadow:
518         case EbtSampler1D:
519         case EbtSampler1DArray:
520         case EbtSampler1DArrayShadow:
521         case EbtSamplerBuffer:
522         case EbtSamplerCubeArray:
523         case EbtSamplerCubeArrayShadow:
524         case EbtSampler1DShadow:
525         case EbtSampler2DRectShadow:
526             return TBasicType::EbtFloat;
527 
528         case EbtISampler2D:
529         case EbtISampler3D:
530         case EbtISamplerCube:
531         case EbtISampler2DArray:
532         case EbtISampler2DMS:
533         case EbtISampler2DMSArray:
534         case EbtISampler1D:
535         case EbtISampler1DArray:
536         case EbtISampler2DRect:
537         case EbtISamplerBuffer:
538         case EbtISamplerCubeArray:
539             return TBasicType::EbtInt;
540 
541         case EbtUSampler2D:
542         case EbtUSampler3D:
543         case EbtUSamplerCube:
544         case EbtUSampler2DArray:
545         case EbtUSampler2DMS:
546         case EbtUSampler2DMSArray:
547         case EbtUSampler1D:
548         case EbtUSampler1DArray:
549         case EbtUSampler2DRect:
550         case EbtUSamplerBuffer:
551         case EbtUSamplerCubeArray:
552             return TBasicType::EbtUInt;
553 
554         default:
555             UNREACHABLE();
556             return TBasicType::EbtVoid;
557     }
558 }
559 
GetTextureTypeName(TBasicType samplerType)560 Name sh::GetTextureTypeName(TBasicType samplerType)
561 {
562     ASSERT(IsSampler(samplerType));
563 
564     const TBasicType textureType = GetTextureBasicType(samplerType);
565     const char *name;
566 
567 #define HANDLE_TEXTURE_NAME(baseName)                \
568     do                                               \
569     {                                                \
570         switch (textureType)                         \
571         {                                            \
572             case TBasicType::EbtFloat:               \
573                 name = "metal::" baseName "<float>"; \
574                 break;                               \
575             case TBasicType::EbtInt:                 \
576                 name = "metal::" baseName "<int>";   \
577                 break;                               \
578             case TBasicType::EbtUInt:                \
579                 name = "metal::" baseName "<uint>";  \
580                 break;                               \
581             default:                                 \
582                 UNREACHABLE();                       \
583                 name = nullptr;                      \
584                 break;                               \
585         }                                            \
586     } while (false)
587 
588     switch (samplerType)
589     {
590         // 1d
591         case EbtSampler1D:  // Desktop GLSL sampler type:
592         case EbtISampler1D:
593         case EbtUSampler1D:
594             HANDLE_TEXTURE_NAME("texture1d");
595             break;
596 
597         // 1d array
598         case EbtSampler1DArray:
599         case EbtISampler1DArray:
600         case EbtUSampler1DArray:
601             HANDLE_TEXTURE_NAME("texture1d_array");
602             break;
603 
604         // Buffer textures
605         case EbtSamplerBuffer:
606         case EbtISamplerBuffer:
607         case EbtUSamplerBuffer:
608             HANDLE_TEXTURE_NAME("texture_buffer");
609             break;
610 
611         // 2d textures
612         case EbtSampler2D:
613         case EbtISampler2D:
614         case EbtUSampler2D:
615         case EbtSampler2DRect:
616         case EbtUSampler2DRect:
617         case EbtISampler2DRect:
618             HANDLE_TEXTURE_NAME("texture2d");
619             break;
620 
621         // 3d textures
622         case EbtSampler3D:
623         case EbtISampler3D:
624         case EbtUSampler3D:
625             HANDLE_TEXTURE_NAME("texture3d");
626             break;
627 
628         // Cube textures
629         case EbtSamplerCube:
630         case EbtISamplerCube:
631         case EbtUSamplerCube:
632             HANDLE_TEXTURE_NAME("texturecube");
633             break;
634 
635         // 2d array textures
636         case EbtSampler2DArray:
637         case EbtUSampler2DArray:
638         case EbtISampler2DArray:
639             HANDLE_TEXTURE_NAME("texture2d_array");
640             break;
641 
642         case EbtSampler2DMS:
643         case EbtISampler2DMS:
644         case EbtUSampler2DMS:
645             HANDLE_TEXTURE_NAME("texture2d_ms");
646             break;
647 
648         case EbtSampler2DMSArray:
649         case EbtISampler2DMSArray:
650         case EbtUSampler2DMSArray:
651             HANDLE_TEXTURE_NAME("texture2d_ms_array");
652             break;
653 
654         // cube array
655         case EbtSamplerCubeArray:
656         case EbtISamplerCubeArray:
657         case EbtUSamplerCubeArray:
658             HANDLE_TEXTURE_NAME("texturecube_array");
659             break;
660 
661         // Shadow
662         case EbtSampler1DShadow:
663         case EbtSampler1DArrayShadow:
664             UNIMPLEMENTED();
665             HANDLE_TEXTURE_NAME("TODO");
666             break;
667 
668         case EbtSampler2DRectShadow:
669         case EbtSampler2DShadow:
670             HANDLE_TEXTURE_NAME("depth2d");
671             break;
672 
673         case EbtSamplerCubeShadow:
674             HANDLE_TEXTURE_NAME("depthcube");
675             break;
676 
677         case EbtSampler2DArrayShadow:
678             HANDLE_TEXTURE_NAME("depth2d_array");
679             break;
680 
681         case EbtSamplerCubeArrayShadow:
682             HANDLE_TEXTURE_NAME("depthcube_array");
683             break;
684 
685         // Extentions
686         case EbtSamplerExternalOES:       // Only valid if OES_EGL_image_external exists:
687         case EbtSamplerExternal2DY2YEXT:  // Only valid if GL_EXT_YUV_target exists:
688         case EbtSamplerVideoWEBGL:
689             UNIMPLEMENTED();
690             HANDLE_TEXTURE_NAME("TODO");
691             break;
692 
693         default:
694             UNREACHABLE();
695             name = nullptr;
696             break;
697     }
698 
699 #undef HANDLE_TEXTURE_NAME
700 
701     return Name(name, SymbolType::UserDefined);
702 }
703