1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <limits>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
12 
13 #include "compiler/translator/Compiler.h"
14 #include "compiler/translator/TranslatorMetalDirect.h"
15 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
16 #include "compiler/translator/TranslatorMetalDirect/ModifyStruct.h"
17 
18 using namespace sh;
19 
20 ////////////////////////////////////////////////////////////////////////////////
21 
size() const22 size_t ModifiedStructMachineries::size() const
23 {
24     return ordering.size();
25 }
26 
at(size_t index) const27 const ModifiedStructMachinery &ModifiedStructMachineries::at(size_t index) const
28 {
29     ASSERT(index < size());
30     const TStructure *s              = ordering[index];
31     const ModifiedStructMachinery *m = find(*s);
32     ASSERT(m);
33     return *m;
34 }
35 
find(const TStructure & s) const36 const ModifiedStructMachinery *ModifiedStructMachineries::find(const TStructure &s) const
37 {
38     auto iter = originalToMachinery.find(&s);
39     if (iter == originalToMachinery.end())
40     {
41         return nullptr;
42     }
43     return &iter->second;
44 }
45 
insert(const TStructure & s,const ModifiedStructMachinery & machinery)46 void ModifiedStructMachineries::insert(const TStructure &s,
47                                        const ModifiedStructMachinery &machinery)
48 {
49     ASSERT(!find(s));
50     originalToMachinery[&s] = machinery;
51     ordering.push_back(&s);
52 }
53 
54 ////////////////////////////////////////////////////////////////////////////////
55 
56 namespace
57 {
58 
Flatten(SymbolEnv & symbolEnv,TIntermTyped & node)59 TIntermTyped &Flatten(SymbolEnv &symbolEnv, TIntermTyped &node)
60 {
61     auto &type = node.getType();
62     ASSERT(type.isArray());
63 
64     auto &retType = InnermostType(type);
65     retType.makeArray(1);
66 
67     return symbolEnv.callFunctionOverload(Name("flatten"), retType, *new TIntermSequence{&node});
68 }
69 
// Empty tag type: constructing a PathItem from it selects PathItem::Type::FlattenArray.
struct FlattenArray
{
};
72 
73 struct PathItem
74 {
75     enum class Type
76     {
77         Field,         // Struct field indexing.
78         Index,         // Array, vector, or matrix indexing.
79         FlattenArray,  // Array of any rank -> pointer of innermost type.
80     };
81 
PathItem__anon383a77860111::PathItem82     PathItem(const TField &field) : field(&field), type(Type::Field) {}
PathItem__anon383a77860111::PathItem83     PathItem(int index) : index(index), type(Type::Index) {}
PathItem__anon383a77860111::PathItem84     PathItem(unsigned index) : PathItem(static_cast<int>(index)) {}
PathItem__anon383a77860111::PathItem85     PathItem(FlattenArray flatten) : type(Type::FlattenArray) {}
86 
87     union
88     {
89         const TField *field;
90         int index;
91     };
92     Type type;
93 };
94 
BuildPathAccess(SymbolEnv & symbolEnv,const TVariable & var,const std::vector<PathItem> & path)95 TIntermTyped &BuildPathAccess(SymbolEnv &symbolEnv,
96                               const TVariable &var,
97                               const std::vector<PathItem> &path)
98 {
99     TIntermTyped *curr = new TIntermSymbol(&var);
100     for (const PathItem &item : path)
101     {
102         switch (item.type)
103         {
104             case PathItem::Type::Field:
105                 curr = &AccessField(*curr, item.field->name());
106                 break;
107             case PathItem::Type::Index:
108                 curr = &AccessIndex(*curr, item.index);
109                 break;
110             case PathItem::Type::FlattenArray:
111             {
112                 curr = &Flatten(symbolEnv, *curr);
113             }
114             break;
115         }
116     }
117     return *curr;
118 }
119 
120 ////////////////////////////////////////////////////////////////////////////////
121 
// Parameters of a generated converter function: a variable of the original struct type and a
// variable of the modified struct type. Both aliases name the same C++ type; the distinct
// names only document which side a variable belongs to.
using OriginalParam = const TVariable &;
using ModifiedParam = const TVariable &;

// AST expressions that refer to a location in the original / modified struct. As above, the
// two aliases are the same type and exist for readability.
using OriginalAccess = TIntermTyped;
using ModifiedAccess = TIntermTyped;
127 
128 struct Access
129 {
130     OriginalAccess &original;
131     ModifiedAccess &modified;
132 
133     struct Env
134     {
135         const ConvertType type;
136     };
137 };
138 
// Callback that, given the conversion environment and the accesses to an original location and
// its modified field, returns the (possibly refined or coerced) pair of accesses to assign.
using ConversionFunc = std::function<Access(Access::Env &, OriginalAccess &, ModifiedAccess &)>;
140 
141 class ConvertStructState : angle::NonCopyable
142 {
143   private:
144     struct ConversionInfo
145     {
146         ConversionFunc stdFunc;
147         const TFunction *astFunc;
148         std::vector<PathItem> pathItems;
149         ImmutableString pathName;
150     };
151 
152   public:
ConvertStructState(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const ModifyStructConfig & config,ModifiedStructMachineries & outMachineries,const bool isUBO)153     ConvertStructState(TCompiler &compiler,
154                        SymbolEnv &symbolEnv,
155                        IdGen &idGen,
156                        const ModifyStructConfig &config,
157                        ModifiedStructMachineries &outMachineries,
158                        const bool isUBO)
159         : mCompiler(compiler),
160           config(config),
161           symbolEnv(symbolEnv),
162           modifiedFields(*new TFieldList()),
163           symbolTable(symbolEnv.symbolTable()),
164           idGen(idGen),
165           outMachineries(outMachineries),
166           isUBO(isUBO)
167     {}
168 
~ConvertStructState()169     ~ConvertStructState()
170     {
171         ASSERT(namePath.empty());
172         ASSERT(namePathSizes.empty());
173     }
174 
publish(const TStructure & originalStruct,const Name & modifiedStructName)175     void publish(const TStructure &originalStruct, const Name &modifiedStructName)
176     {
177         const bool isOriginalToModified = config.convertType == ConvertType::OriginalToModified;
178 
179         auto &modifiedStruct = *new TStructure(&symbolTable, modifiedStructName.rawName(),
180                                                &modifiedFields, modifiedStructName.symbolType());
181 
182         auto &func = *new TFunction(
183             &symbolTable,
184             idGen.createNewName(isOriginalToModified ? "originalToModified" : "modifiedToOriginal")
185                 .rawName(),
186             SymbolType::AngleInternal, new TType(TBasicType::EbtVoid), false);
187 
188         OriginalParam originalParam =
189             CreateInstanceVariable(symbolTable, originalStruct, Name("original"));
190         ModifiedParam modifiedParam =
191             CreateInstanceVariable(symbolTable, modifiedStruct, Name("modified"));
192 
193         symbolEnv.markAsReference(originalParam, AddressSpace::Thread);
194         symbolEnv.markAsReference(modifiedParam, config.externalAddressSpace);
195         if (isOriginalToModified)
196         {
197             func.addParameter(&originalParam);
198             func.addParameter(&modifiedParam);
199         }
200         else
201         {
202             func.addParameter(&modifiedParam);
203             func.addParameter(&originalParam);
204         }
205 
206         TIntermBlock &body = *new TIntermBlock();
207 
208         Access::Env env{config.convertType};
209 
210         for (ConversionInfo &info : conversionInfos)
211         {
212             auto convert = [&](OriginalAccess &original, ModifiedAccess &modified) {
213                 if (info.astFunc)
214                 {
215                     ASSERT(!info.stdFunc);
216                     TIntermTyped &src  = isOriginalToModified ? modified : original;
217                     TIntermTyped &dest = isOriginalToModified ? original : modified;
218                     body.appendStatement(TIntermAggregate::CreateFunctionCall(
219                         *info.astFunc, new TIntermSequence{&dest, &src}));
220                 }
221                 else
222                 {
223                     ASSERT(info.stdFunc);
224                     Access access      = info.stdFunc(env, original, modified);
225                     TIntermTyped &src  = isOriginalToModified ? access.original : access.modified;
226                     TIntermTyped &dest = isOriginalToModified ? access.modified : access.original;
227                     body.appendStatement(new TIntermBinary(TOperator::EOpAssign, &dest, &src));
228                 }
229             };
230 
231             OriginalAccess *original = &BuildPathAccess(symbolEnv, originalParam, info.pathItems);
232             ModifiedAccess *modified = &AccessField(modifiedParam, info.pathName);
233 
234             const TType ot = original->getType();
235             const TType mt = modified->getType();
236             ASSERT(ot.isArray() == mt.isArray());
237 
238             if (ot.isArray() && (ot.getLayoutQualifier().matrixPacking == EmpRowMajor || ot != mt))
239             {
240                 ASSERT(ot.getArraySizes() == mt.getArraySizes());
241                 if (ot.isArrayOfArrays())
242                 {
243                     original = &Flatten(symbolEnv, *original);
244                     modified = &Flatten(symbolEnv, *modified);
245                 }
246                 const int volume = static_cast<int>(ot.getArraySizeProduct());
247                 for (int i = 0; i < volume; ++i)
248                 {
249                     if (i != 0)
250                     {
251                         original = original->deepCopy();
252                         modified = modified->deepCopy();
253                     }
254                     OriginalAccess &o = AccessIndex(*original, i);
255                     OriginalAccess &m = AccessIndex(*modified, i);
256                     convert(o, m);
257                 }
258             }
259             else
260             {
261                 convert(*original, *modified);
262             }
263         }
264 
265         auto *funcProto = new TIntermFunctionPrototype(&func);
266         auto *funcDef   = new TIntermFunctionDefinition(funcProto, &body);
267 
268         ModifiedStructMachinery machinery;
269         machinery.modifiedStruct                   = &modifiedStruct;
270         machinery.getConverter(config.convertType) = funcDef;
271 
272         outMachineries.insert(originalStruct, machinery);
273     }
274 
pushPath(PathItem const & item)275     void pushPath(PathItem const &item)
276     {
277         pathItems.push_back(item);
278 
279         switch (item.type)
280         {
281             case PathItem::Type::Field:
282                 pushNamePath(item.field->name().data());
283                 break;
284 
285             case PathItem::Type::Index:
286                 pushNamePath(item.index);
287                 break;
288 
289             case PathItem::Type::FlattenArray:
290                 namePathSizes.push_back(namePath.size());
291                 break;
292         }
293     }
294 
popPath()295     void popPath()
296     {
297         ASSERT(!namePath.empty());
298         ASSERT(!namePathSizes.empty());
299         namePath.resize(namePathSizes.back());
300         namePathSizes.pop_back();
301 
302         ASSERT(!pathItems.empty());
303         pathItems.pop_back();
304     }
305 
finalize(const bool allowPadding)306     void finalize(const bool allowPadding)
307     {
308         ASSERT(!finalized);
309         finalized = true;
310         introducePacking();
311         ASSERT(metalLayoutTotal == Layout::Identity());
312         // Only pad substructs. We don't want to pad the structure that contains all the UBOs, only
313         // individual UBOs.
314         if (allowPadding)
315             introducePadding();
316     }
317 
addModifiedField(const TField & field,TType & newType,TLayoutBlockStorage storage,TLayoutMatrixPacking packing,const AddressSpace * addressSpace)318     void addModifiedField(const TField &field,
319                           TType &newType,
320                           TLayoutBlockStorage storage,
321                           TLayoutMatrixPacking packing,
322                           const AddressSpace *addressSpace)
323     {
324         TLayoutQualifier layoutQualifier = newType.getLayoutQualifier();
325         layoutQualifier.blockStorage     = storage;
326         layoutQualifier.matrixPacking    = packing;
327         newType.setLayoutQualifier(layoutQualifier);
328 
329         const ImmutableString pathName(namePath);
330         TField *modifiedField = new TField(&newType, pathName, field.line(), field.symbolType());
331         if (addressSpace)
332         {
333             symbolEnv.markAsPointer(*modifiedField, *addressSpace);
334         }
335         if (symbolEnv.isUBO(field))
336         {
337             symbolEnv.markAsUBO(*modifiedField);
338         }
339         modifiedFields.push_back(modifiedField);
340     }
341 
addConversion(const ConversionFunc & func)342     void addConversion(const ConversionFunc &func)
343     {
344         ASSERT(!modifiedFields.empty());
345         conversionInfos.push_back({func, nullptr, pathItems, modifiedFields.back()->name()});
346     }
347 
addConversion(const TFunction & func)348     void addConversion(const TFunction &func)
349     {
350         ASSERT(!modifiedFields.empty());
351         conversionInfos.push_back({{}, &func, pathItems, modifiedFields.back()->name()});
352     }
353 
hasPacking() const354     bool hasPacking() const { return containsPacked; }
355 
hasPadding() const356     bool hasPadding() const { return padFieldCount > 0; }
357 
recurse(const TStructure & structure,ModifiedStructMachinery & outMachinery,const bool isUBORecurse)358     bool recurse(const TStructure &structure,
359                  ModifiedStructMachinery &outMachinery,
360                  const bool isUBORecurse)
361     {
362         const ModifiedStructMachinery *m = outMachineries.find(structure);
363         if (m == nullptr)
364         {
365             TranslatorMetalReflection *reflection =
366                 ((sh::TranslatorMetalDirect *)&mCompiler)->getTranslatorMetalReflection();
367             reflection->addOriginalName(structure.uniqueId().get(), structure.name().data());
368             const Name name = idGen.createNewName(structure.name().data());
369             if (!TryCreateModifiedStruct(mCompiler, symbolEnv, idGen, config, structure, name,
370                                          outMachineries, isUBORecurse, true))
371             {
372                 return false;
373             }
374             m = outMachineries.find(structure);
375             ASSERT(m);
376         }
377         outMachinery = *m;
378         return true;
379     }
380 
getIsUBO() const381     bool getIsUBO() const { return isUBO; }
382 
383   private:
addPadding(size_t padAmount,bool updateLayout)384     void addPadding(size_t padAmount, bool updateLayout)
385     {
386         if (padAmount == 0)
387         {
388             return;
389         }
390 
391         const size_t begin = modifiedFields.size();
392 
393         // Iteratively adding in scalar or vector padding because some struct types will not
394         // allow matrix or array members.
395         while (padAmount > 0)
396         {
397             TType *padType;
398             if (padAmount >= 16)
399             {
400                 padAmount -= 16;
401                 padType = new TType(TBasicType::EbtFloat, 4);
402             }
403             else if (padAmount >= 8)
404             {
405                 padAmount -= 8;
406                 padType = new TType(TBasicType::EbtFloat, 2);
407             }
408             else if (padAmount >= 4)
409             {
410                 padAmount -= 4;
411                 padType = new TType(TBasicType::EbtFloat);
412             }
413             else if (padAmount >= 2)
414             {
415                 padAmount -= 2;
416                 padType = new TType(TBasicType::EbtBool, 2);
417             }
418             else
419             {
420                 ASSERT(padAmount == 1);
421                 padAmount -= 1;
422                 padType = new TType(TBasicType::EbtBool);
423             }
424 
425             if (updateLayout)
426             {
427                 metalLayoutTotal += MetalLayoutOf(*padType);
428             }
429 
430             const Name name = idGen.createNewName("pad");
431             modifiedFields.push_back(
432                 new TField(padType, name.rawName(), kNoSourceLoc, name.symbolType()));
433             ++padFieldCount;
434         }
435 
436         std::reverse(modifiedFields.begin() + begin, modifiedFields.end());
437     }
438 
introducePacking()439     void introducePacking()
440     {
441         if (!config.allowPacking)
442         {
443             return;
444         }
445 
446         auto setUnpackedStorage = [](TType &type) {
447             TLayoutBlockStorage storage = type.getLayoutQualifier().blockStorage;
448             switch (storage)
449             {
450                 case TLayoutBlockStorage::EbsShared:
451                     storage = TLayoutBlockStorage::EbsStd140;
452                     break;
453                 case TLayoutBlockStorage::EbsPacked:
454                     storage = TLayoutBlockStorage::EbsStd430;
455                     break;
456                 case TLayoutBlockStorage::EbsStd140:
457                 case TLayoutBlockStorage::EbsStd430:
458                 case TLayoutBlockStorage::EbsUnspecified:
459                     break;
460             }
461             SetBlockStorage(type, storage);
462         };
463 
464         Layout glslLayoutTotal = Layout::Identity();
465         const size_t size      = modifiedFields.size();
466 
467         for (size_t i = 0; i < size; ++i)
468         {
469             TField &curr           = *modifiedFields[i];
470             TType &currType        = *curr.type();
471             const bool canBePacked = CanBePacked(currType);
472 
473             auto dontPack = [&]() {
474                 if (canBePacked)
475                 {
476                     setUnpackedStorage(currType);
477                 }
478                 glslLayoutTotal += GlslLayoutOf(currType);
479             };
480 
481             if (!CanBePacked(currType))
482             {
483                 dontPack();
484                 continue;
485             }
486 
487             const Layout packedGlslLayout           = GlslLayoutOf(currType);
488             const TLayoutBlockStorage packedStorage = currType.getLayoutQualifier().blockStorage;
489             setUnpackedStorage(currType);
490             const Layout unpackedGlslLayout = GlslLayoutOf(currType);
491             SetBlockStorage(currType, packedStorage);
492 
493             ASSERT(packedGlslLayout.sizeOf <= unpackedGlslLayout.sizeOf);
494             if (packedGlslLayout.sizeOf == unpackedGlslLayout.sizeOf)
495             {
496                 dontPack();
497                 continue;
498             }
499 
500             const size_t j = i + 1;
501             if (j == size)
502             {
503                 dontPack();
504                 break;
505             }
506 
507             const size_t pad            = unpackedGlslLayout.sizeOf - packedGlslLayout.sizeOf;
508             const TField &next          = *modifiedFields[j];
509             const Layout nextGlslLayout = GlslLayoutOf(*next.type());
510 
511             if (pad < nextGlslLayout.sizeOf)
512             {
513                 dontPack();
514                 continue;
515             }
516 
517             symbolEnv.markAsPacked(curr);
518             glslLayoutTotal += packedGlslLayout;
519             containsPacked = true;
520         }
521     }
522 
introducePadding()523     void introducePadding()
524     {
525         if (!config.allowPadding)
526         {
527             return;
528         }
529 
530         MetalLayoutOfConfig layoutConfig;
531         layoutConfig.disablePacking             = !config.allowPacking;
532         layoutConfig.assumeStructsAreTailPadded = true;
533 
534         TFieldList fields = std::move(modifiedFields);
535         ASSERT(!fields.empty());  // GLSL requires at least one member.
536 
537         const TField *const first = fields.front();
538 
539         for (TField *field : fields)
540         {
541             const TType &type = *field->type();
542 
543             const Layout glslLayout  = GlslLayoutOf(type);
544             const Layout metalLayout = MetalLayoutOf(type, layoutConfig);
545 
546             size_t prePadAmount = 0;
547             if (glslLayout.alignOf > metalLayout.alignOf && field != first)
548             {
549                 const size_t prePaddedSize = metalLayoutTotal.sizeOf;
550                 metalLayoutTotal.requireAlignment(glslLayout.alignOf, true);
551                 const size_t paddedSize = metalLayoutTotal.sizeOf;
552                 prePadAmount            = paddedSize - prePaddedSize;
553                 metalLayoutTotal += metalLayout;
554                 addPadding(prePadAmount, false);  // Note: requireAlignment() already updated layout
555             }
556             else
557             {
558                 metalLayoutTotal += metalLayout;
559             }
560 
561             modifiedFields.push_back(field);
562 
563             if (glslLayout.sizeOf > metalLayout.sizeOf && field != fields.back())
564             {
565                 const bool updateLayout = true;  // XXX: Correct?
566                 const size_t padAmount  = glslLayout.sizeOf - metalLayout.sizeOf;
567                 addPadding(padAmount, updateLayout);
568             }
569         }
570     }
571 
pushNamePath(const char * extra)572     void pushNamePath(const char *extra)
573     {
574         ASSERT(extra && *extra != '\0');
575         namePathSizes.push_back(namePath.size());
576         const char *p = extra;
577         if (namePath.empty())
578         {
579             namePath = p;
580             return;
581         }
582         while (*p == '_')
583         {
584             ++p;
585         }
586         if (*p == '\0')
587         {
588             p = "x";
589         }
590         if (namePath.back() != '_')
591         {
592             namePath += '_';
593         }
594         namePath += p;
595     }
596 
pushNamePath(unsigned extra)597     void pushNamePath(unsigned extra)
598     {
599         char buffer[std::numeric_limits<unsigned>::digits10 + 1];
600         sprintf(buffer, "%u", extra);
601         pushNamePath(buffer);
602     }
603 
604   public:
605     TCompiler &mCompiler;
606     const ModifyStructConfig &config;
607     SymbolEnv &symbolEnv;
608 
609   private:
610     TFieldList &modifiedFields;
611     Layout metalLayoutTotal = Layout::Identity();
612     size_t padFieldCount    = 0;
613     bool containsPacked     = false;
614     bool finalized          = false;
615 
616     std::vector<PathItem> pathItems;
617 
618     std::vector<size_t> namePathSizes;
619     std::string namePath;
620 
621     std::vector<ConversionInfo> conversionInfos;
622     TSymbolTable &symbolTable;
623     IdGen &idGen;
624     ModifiedStructMachineries &outMachineries;
625     const bool isUBO;
626 };
627 
628 ////////////////////////////////////////////////////////////////////////////////
629 
// Signature shared by the field-modification strategies below. A strategy receives the
// conversion state, the original field, and the block storage / matrix packing to apply, and
// returns a bool (the strategies visible here return false when they do not apply to the field).
using ModifyFunc = bool (*)(ConvertStructState &state,
                            const TField &field,
                            const TLayoutBlockStorage storage,
                            const TLayoutMatrixPacking packing);

// Forward declaration; the definition is not shown in this part of the file. InlineStruct uses it
// to process each subfield of an inlined struct.
bool ModifyRecursive(ConvertStructState &state,
                     const TField &field,
                     const TLayoutBlockStorage storage,
                     const TLayoutMatrixPacking packing);
639 
IdentityModify(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)640 bool IdentityModify(ConvertStructState &state,
641                     const TField &field,
642                     const TLayoutBlockStorage storage,
643                     const TLayoutMatrixPacking packing)
644 {
645     const TType &type = *field.type();
646     state.addModifiedField(field, CloneType(type), storage, packing, nullptr);
647     state.addConversion([=](Access::Env &, OriginalAccess &o, ModifiedAccess &m) {
648         return Access{o, m};
649     });
650     return false;
651 }
652 
InlineStruct(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)653 bool InlineStruct(ConvertStructState &state,
654                   const TField &field,
655                   const TLayoutBlockStorage storage,
656                   const TLayoutMatrixPacking packing)
657 {
658     const TType &type              = *field.type();
659     const TStructure *substructure = state.symbolEnv.remap(type.getStruct());
660     if (!substructure)
661     {
662         return false;
663     }
664     if (type.isArray())
665     {
666         return false;
667     }
668     if (!state.config.inlineStruct(field))
669     {
670         return false;
671     }
672 
673     const TFieldList &subfields = substructure->fields();
674     for (const TField *subfield : subfields)
675     {
676         const TType &subtype                  = *subfield->type();
677         const TLayoutBlockStorage substorage  = Overlay(storage, subtype);
678         const TLayoutMatrixPacking subpacking = Overlay(packing, subtype);
679         ModifyRecursive(state, *subfield, substorage, subpacking);
680     }
681 
682     return true;
683 }
684 
RecurseStruct(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)685 bool RecurseStruct(ConvertStructState &state,
686                    const TField &field,
687                    const TLayoutBlockStorage storage,
688                    const TLayoutMatrixPacking packing)
689 {
690     const TType &type              = *field.type();
691     const TStructure *substructure = state.symbolEnv.remap(type.getStruct());
692     if (!substructure)
693     {
694         return false;
695     }
696     if (!state.config.recurseStruct(field))
697     {
698         return false;
699     }
700 
701     ModifiedStructMachinery machinery;
702     if (!state.recurse(*substructure, machinery, state.getIsUBO()))
703     {
704         return false;
705     }
706 
707     TType &newType = *new TType(machinery.modifiedStruct, false);
708     if (type.isArray())
709     {
710         newType.makeArrays(type.getArraySizes());
711     }
712 
713     TIntermFunctionDefinition *converter = machinery.getConverter(state.config.convertType);
714     ASSERT(converter);
715 
716     state.addModifiedField(field, newType, storage, packing, state.symbolEnv.isPointer(field));
717     if (state.symbolEnv.isPointer(field))
718     {
719         state.symbolEnv.removePointer(field);
720     }
721     state.addConversion(*converter->getFunction());
722 
723     return true;
724 }
725 
SplitMatrixColumns(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)726 bool SplitMatrixColumns(ConvertStructState &state,
727                         const TField &field,
728                         const TLayoutBlockStorage storage,
729                         const TLayoutMatrixPacking packing)
730 {
731     const TType &type = *field.type();
732     if (!type.isMatrix())
733     {
734         return false;
735     }
736 
737     if (!state.config.splitMatrixColumns(field))
738     {
739         return false;
740     }
741 
742     const int cols = type.getCols();
743     TType &rowType = DropColumns(type);
744 
745     for (int c = 0; c < cols; ++c)
746     {
747         state.pushPath(c);
748 
749         state.addModifiedField(field, rowType, storage, packing, state.symbolEnv.isPointer(field));
750         if (state.symbolEnv.isPointer(field))
751         {
752             state.symbolEnv.removePointer(field);
753         }
754         state.addConversion([=](Access::Env &, OriginalAccess &o, ModifiedAccess &m) {
755             return Access{o, m};
756         });
757 
758         state.popPath();
759     }
760 
761     return true;
762 }
763 
SaturateMatrixRows(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)764 bool SaturateMatrixRows(ConvertStructState &state,
765                         const TField &field,
766                         const TLayoutBlockStorage storage,
767                         const TLayoutMatrixPacking packing)
768 {
769     const TType &type = *field.type();
770     if (!type.isMatrix())
771     {
772         return false;
773     }
774     const bool isRowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
775     const int rows        = type.getRows();
776     const int saturation  = state.config.saturateMatrixRows(field);
777     if (saturation <= rows && !isRowMajor)
778     {
779         return false;
780     }
781 
782     const int cols = type.getCols();
783     TType &satType = SetMatrixRowDim(type, saturation);
784     state.addModifiedField(field, satType, storage, packing, state.symbolEnv.isPointer(field));
785     if (state.symbolEnv.isPointer(field))
786     {
787         state.symbolEnv.removePointer(field);
788     }
789 
790     for (int c = 0; c < cols; ++c)
791     {
792         for (int r = 0; r < rows; ++r)
793         {
794             state.addConversion([=](Access::Env &, OriginalAccess &o, ModifiedAccess &m) {
795                 int firstModifiedIndex  = isRowMajor ? r : c;
796                 int secondModifiedIndex = isRowMajor ? c : r;
797                 auto &o_                = AccessIndex(AccessIndex(o, c), r);
798                 auto &m_ = AccessIndex(AccessIndex(m, firstModifiedIndex), secondModifiedIndex);
799                 return Access{o_, m_};
800             });
801         }
802     }
803 
804     return true;
805 }
806 
TestBoolToUint(ConvertStructState & state,const TField & field)807 bool TestBoolToUint(ConvertStructState &state, const TField &field)
808 {
809     if (field.type()->getBasicType() != TBasicType::EbtBool)
810     {
811         return false;
812     }
813     if (!state.config.promoteBoolToUint(field))
814     {
815         return false;
816     }
817     return true;
818 }
819 
ConvertBoolToUint(ConvertType convertType,OriginalAccess & o,ModifiedAccess & m)820 Access ConvertBoolToUint(ConvertType convertType, OriginalAccess &o, ModifiedAccess &m)
821 {
822     auto coerce = [](TIntermTyped &to, TIntermTyped &from) -> TIntermTyped & {
823         return *TIntermAggregate::CreateConstructor(to.getType(), new TIntermSequence{&from});
824     };
825     switch (convertType)
826     {
827         case ConvertType::OriginalToModified:
828             return Access{coerce(m, o), m};
829         case ConvertType::ModifiedToOriginal:
830             return Access{o, coerce(o, m)};
831     }
832 }
833 
SaturateScalarOrVectorCommon(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing,const bool array)834 bool SaturateScalarOrVectorCommon(ConvertStructState &state,
835                                   const TField &field,
836                                   const TLayoutBlockStorage storage,
837                                   const TLayoutMatrixPacking packing,
838                                   const bool array)
839 {
840     const TType &type = *field.type();
841     if (type.isArray() != array)
842     {
843         return false;
844     }
845     if (!((type.isRank0() && HasScalarBasicType(type)) || type.isVector()))
846     {
847         return false;
848     }
849     const auto saturator =
850         array ? state.config.saturateScalarOrVectorArrays : state.config.saturateScalarOrVector;
851     const int dim        = type.getNominalSize();
852     const int saturation = saturator(field);
853     if (saturation <= dim)
854     {
855         return false;
856     }
857 
858     TType &satType        = SetVectorDim(type, saturation);
859     const bool boolToUint = TestBoolToUint(state, field);
860     if (boolToUint)
861     {
862         satType.setBasicType(TBasicType::EbtUInt);
863     }
864     state.addModifiedField(field, satType, storage, packing, state.symbolEnv.isPointer(field));
865     if (state.symbolEnv.isPointer(field))
866     {
867         state.symbolEnv.removePointer(field);
868     }
869 
870     for (int d = 0; d < dim; ++d)
871     {
872         state.addConversion([=](Access::Env &env, OriginalAccess &o, ModifiedAccess &m) {
873             auto &o_ = dim > 1 ? AccessIndex(o, d) : o;
874             auto &m_ = AccessIndex(m, d);
875             if (boolToUint)
876             {
877                 return ConvertBoolToUint(env.type, o_, m_);
878             }
879             else
880             {
881                 return Access{o_, m_};
882             }
883         });
884     }
885 
886     return true;
887 }
888 
SaturateScalarOrVectorArrays(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)889 bool SaturateScalarOrVectorArrays(ConvertStructState &state,
890                                   const TField &field,
891                                   const TLayoutBlockStorage storage,
892                                   const TLayoutMatrixPacking packing)
893 {
894     return SaturateScalarOrVectorCommon(state, field, storage, packing, true);
895 }
896 
SaturateScalarOrVector(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)897 bool SaturateScalarOrVector(ConvertStructState &state,
898                             const TField &field,
899                             const TLayoutBlockStorage storage,
900                             const TLayoutMatrixPacking packing)
901 {
902     return SaturateScalarOrVectorCommon(state, field, storage, packing, false);
903 }
904 
PromoteBoolToUint(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)905 bool PromoteBoolToUint(ConvertStructState &state,
906                        const TField &field,
907                        const TLayoutBlockStorage storage,
908                        const TLayoutMatrixPacking packing)
909 {
910     if (!TestBoolToUint(state, field))
911     {
912         return false;
913     }
914 
915     auto &promotedType = CloneType(*field.type());
916     promotedType.setBasicType(TBasicType::EbtUInt);
917     state.addModifiedField(field, promotedType, storage, packing, state.symbolEnv.isPointer(field));
918     if (state.symbolEnv.isPointer(field))
919     {
920         state.symbolEnv.removePointer(field);
921     }
922 
923     state.addConversion([=](Access::Env &env, OriginalAccess &o, ModifiedAccess &m) {
924         return ConvertBoolToUint(env.type, o, m);
925     });
926 
927     return true;
928 }
929 
ModifyCommon(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)930 bool ModifyCommon(ConvertStructState &state,
931                   const TField &field,
932                   const TLayoutBlockStorage storage,
933                   const TLayoutMatrixPacking packing)
934 {
935     ModifyFunc funcs[] = {
936         InlineStruct,                  //
937         RecurseStruct,                 //
938         SplitMatrixColumns,            //
939         SaturateMatrixRows,            //
940         SaturateScalarOrVectorArrays,  //
941         SaturateScalarOrVector,        //
942         PromoteBoolToUint,             //
943     };
944 
945     for (ModifyFunc func : funcs)
946     {
947         if (func(state, field, storage, packing))
948         {
949             return true;
950         }
951     }
952 
953     return IdentityModify(state, field, storage, packing);
954 }
955 
InlineArray(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)956 bool InlineArray(ConvertStructState &state,
957                  const TField &field,
958                  const TLayoutBlockStorage storage,
959                  const TLayoutMatrixPacking packing)
960 {
961     const TType &type = *field.type();
962     if (!type.isArray())
963     {
964         return false;
965     }
966     if (!state.config.inlineArray(field))
967     {
968         return false;
969     }
970 
971     const unsigned volume = type.getArraySizeProduct();
972     const bool isMultiDim = type.isArrayOfArrays();
973 
974     auto &innermostType = InnermostType(type);
975     const TField innermostField(&innermostType, field.name(), field.line(), field.symbolType());
976 
977     if (isMultiDim)
978     {
979         state.pushPath(FlattenArray());
980     }
981 
982     for (unsigned i = 0; i < volume; ++i)
983     {
984         state.pushPath(i);
985         ModifyCommon(state, innermostField, storage, packing);
986         state.popPath();
987     }
988 
989     if (isMultiDim)
990     {
991         state.popPath();
992     }
993 
994     return true;
995 }
996 
ModifyRecursive(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)997 bool ModifyRecursive(ConvertStructState &state,
998                      const TField &field,
999                      const TLayoutBlockStorage storage,
1000                      const TLayoutMatrixPacking packing)
1001 {
1002     state.pushPath(field);
1003 
1004     bool modified;
1005     if (InlineArray(state, field, storage, packing))
1006     {
1007         modified = true;
1008     }
1009     else
1010     {
1011         modified = ModifyCommon(state, field, storage, packing);
1012     }
1013 
1014     state.popPath();
1015 
1016     return modified;
1017 }
1018 
1019 }  // anonymous namespace
1020 
1021 ////////////////////////////////////////////////////////////////////////////////
1022 
TryCreateModifiedStruct(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const ModifyStructConfig & config,const TStructure & originalStruct,const Name & modifiedStructName,ModifiedStructMachineries & outMachineries,const bool isUBO,const bool allowPadding)1023 bool sh::TryCreateModifiedStruct(TCompiler &compiler,
1024                                  SymbolEnv &symbolEnv,
1025                                  IdGen &idGen,
1026                                  const ModifyStructConfig &config,
1027                                  const TStructure &originalStruct,
1028                                  const Name &modifiedStructName,
1029                                  ModifiedStructMachineries &outMachineries,
1030                                  const bool isUBO,
1031                                  const bool allowPadding)
1032 {
1033     ConvertStructState state(compiler, symbolEnv, idGen, config, outMachineries, isUBO);
1034     size_t identicalFieldCount = 0;
1035 
1036     const TFieldList &originalFields = originalStruct.fields();
1037     for (TField *originalField : originalFields)
1038     {
1039         const TType &originalType          = *originalField->type();
1040         const TLayoutBlockStorage storage  = Overlay(config.initialBlockStorage, originalType);
1041         const TLayoutMatrixPacking packing = Overlay(config.initialMatrixPacking, originalType);
1042         if (!ModifyRecursive(state, *originalField, storage, packing))
1043         {
1044             ++identicalFieldCount;
1045         }
1046     }
1047 
1048     state.finalize(allowPadding);
1049 
1050     if (identicalFieldCount == originalFields.size() && !state.hasPacking() &&
1051         !state.hasPadding() && !isUBO)
1052     {
1053         return false;
1054     }
1055 
1056     state.publish(originalStruct, modifiedStructName);
1057 
1058     return true;
1059 }
1060