1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cstring>
8 #include <unordered_map>
9 #include <unordered_set>
10 
11 #include "compiler/translator/TranslatorMetalDirect.h"
12 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
13 #include "compiler/translator/TranslatorMetalDirect/DiscoverDependentFunctions.h"
14 #include "compiler/translator/TranslatorMetalDirect/IdGen.h"
15 #include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
16 #include "compiler/translator/TranslatorMetalDirect/MapSymbols.h"
17 #include "compiler/translator/TranslatorMetalDirect/Pipeline.h"
18 #include "compiler/translator/TranslatorMetalDirect/RewritePipelines.h"
19 #include "compiler/translator/TranslatorMetalDirect/SymbolEnv.h"
20 #include "compiler/translator/tree_ops/PruneNoOps.h"
21 #include "compiler/translator/tree_util/DriverUniform.h"
22 #include "compiler/translator/tree_util/FindMain.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 using namespace sh;
25 
26 ////////////////////////////////////////////////////////////////////////////////
27 
28 namespace
29 {
30 
31 using VariableSet  = std::unordered_set<const TVariable *>;
32 using VariableList = std::vector<const TVariable *>;
33 
34 ////////////////////////////////////////////////////////////////////////////////
35 
36 struct PipelineStructInfo
37 {
38     VariableSet pipelineVariables;
39     PipelineScoped<TStructure> pipelineStruct;
40     const TFunction *funcOriginalToModified = nullptr;
41     const TFunction *funcModifiedToOriginal = nullptr;
42 
isEmpty__anond06a1e240111::PipelineStructInfo43     bool isEmpty() const
44     {
45         if (pipelineStruct.isTotallyEmpty())
46         {
47             ASSERT(pipelineVariables.empty());
48             return true;
49         }
50         else
51         {
52             ASSERT(pipelineStruct.isTotallyFull());
53             ASSERT(!pipelineVariables.empty());
54             return false;
55         }
56     }
57 };
58 
59 class GeneratePipelineStruct : private TIntermRebuild
60 {
61   private:
62     const Pipeline &mPipeline;
63     SymbolEnv &mSymbolEnv;
64     Invariants &mInvariants;
65     VariableList mPipelineVariableList;
66     IdGen &mIdGen;
67     PipelineStructInfo mInfo;
68 
69   public:
Exec(PipelineStructInfo & out,TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,Invariants & invariants)70     static bool Exec(PipelineStructInfo &out,
71                      TCompiler &compiler,
72                      TIntermBlock &root,
73                      IdGen &idGen,
74                      const Pipeline &pipeline,
75                      SymbolEnv &symbolEnv,
76                      Invariants &invariants)
77     {
78         GeneratePipelineStruct self(compiler, idGen, pipeline, symbolEnv, invariants);
79         if (!self.exec(root))
80         {
81             return false;
82         }
83         out = self.mInfo;
84         return true;
85     }
86 
87   private:
GeneratePipelineStruct(TCompiler & compiler,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,Invariants & invariants)88     GeneratePipelineStruct(TCompiler &compiler,
89                            IdGen &idGen,
90                            const Pipeline &pipeline,
91                            SymbolEnv &symbolEnv,
92                            Invariants &invariants)
93         : TIntermRebuild(compiler, true, true),
94           mPipeline(pipeline),
95           mSymbolEnv(symbolEnv),
96           mInvariants(invariants),
97           mIdGen(idGen)
98     {}
99 
exec(TIntermBlock & root)100     bool exec(TIntermBlock &root)
101     {
102         if (!rebuildRoot(root))
103         {
104             return false;
105         }
106 
107         if (mInfo.pipelineVariables.empty())
108         {
109             return true;
110         }
111 
112         TIntermSequence seq;
113 
114         const TStructure &pipelineStruct = [&]() -> const TStructure & {
115             if (mPipeline.globalInstanceVar)
116             {
117                 return *mPipeline.globalInstanceVar->getType().getStruct();
118             }
119             else
120             {
121                 return createInternalPipelineStruct(root, seq);
122             }
123         }();
124 
125         ModifiedStructMachineries modifiedMachineries;
126         const bool isUBO    = mPipeline.type == Pipeline::Type::UniformBuffer;
127         const bool modified = TryCreateModifiedStruct(
128             mCompiler, mSymbolEnv, mIdGen, mPipeline.externalStructModifyConfig(), pipelineStruct,
129             mPipeline.getStructTypeName(Pipeline::Variant::Modified), modifiedMachineries, isUBO,
130             !isUBO);
131 
132         if (modified)
133         {
134             ASSERT(mPipeline.type != Pipeline::Type::Texture);
135             ASSERT(mPipeline.type == Pipeline::Type::AngleUniforms ||
136                    !mPipeline.globalInstanceVar);  // This shouldn't happen by construction.
137 
138             auto getFunction = [](sh::TIntermFunctionDefinition *funcDecl) {
139                 return funcDecl ? funcDecl->getFunction() : nullptr;
140             };
141 
142             const size_t size = modifiedMachineries.size();
143             ASSERT(size > 0);
144             for (size_t i = 0; i < size; ++i)
145             {
146                 const ModifiedStructMachinery &machinery = modifiedMachineries.at(i);
147                 ASSERT(machinery.modifiedStruct);
148 
149                 seq.push_back(new TIntermDeclaration{
150                     &CreateStructTypeVariable(mSymbolTable, *machinery.modifiedStruct)});
151 
152                 if (mPipeline.isPipelineOut())
153                 {
154                     ASSERT(machinery.funcOriginalToModified);
155                     ASSERT(!machinery.funcModifiedToOriginal);
156                     seq.push_back(machinery.funcOriginalToModified);
157                 }
158                 else
159                 {
160                     ASSERT(machinery.funcModifiedToOriginal);
161                     ASSERT(!machinery.funcOriginalToModified);
162                     seq.push_back(machinery.funcModifiedToOriginal);
163                 }
164 
165                 if (i == size - 1)
166                 {
167                     mInfo.funcOriginalToModified = getFunction(machinery.funcOriginalToModified);
168                     mInfo.funcModifiedToOriginal = getFunction(machinery.funcModifiedToOriginal);
169 
170                     mInfo.pipelineStruct.internal = &pipelineStruct;
171                     mInfo.pipelineStruct.external =
172                         modified ? machinery.modifiedStruct : &pipelineStruct;
173                 }
174             }
175         }
176         else
177         {
178             mInfo.pipelineStruct.internal = &pipelineStruct;
179             mInfo.pipelineStruct.external = &pipelineStruct;
180         }
181 
182         root.insertChildNodes(FindMainIndex(&root), seq);
183 
184         return true;
185     }
186 
187   private:
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)188     PreResult visitFunctionDefinitionPre(TIntermFunctionDefinition &node) override
189     {
190         return {node, VisitBits::Neither};
191     }
visitDeclarationPost(TIntermDeclaration & declNode)192     PostResult visitDeclarationPost(TIntermDeclaration &declNode) override
193     {
194         Declaration decl     = ViewDeclaration(declNode);
195         const TVariable &var = decl.symbol.variable();
196         if (mPipeline.uses(var))
197         {
198             ASSERT(mInfo.pipelineVariables.find(&var) == mInfo.pipelineVariables.end());
199             mInfo.pipelineVariables.insert(&var);
200             mPipelineVariableList.push_back(&var);
201             return nullptr;
202         }
203 
204         return declNode;
205     }
206 
createInternalPipelineStruct(TIntermBlock & root,TIntermSequence & outDeclSeq)207     const TStructure &createInternalPipelineStruct(TIntermBlock &root, TIntermSequence &outDeclSeq)
208     {
209         auto &fields = *new TFieldList();
210 
211         switch (mPipeline.type)
212         {
213             case Pipeline::Type::Texture:
214             {
215                 for (const TVariable *var : mPipelineVariableList)
216                 {
217                     ASSERT(!mInvariants.contains(*var));
218                     const TType &varType         = var->getType();
219                     const TBasicType samplerType = varType.getBasicType();
220 
221                     const TStructure &textureEnv = mSymbolEnv.getTextureEnv(samplerType);
222                     auto *textureEnvType         = new TType(&textureEnv, false);
223                     if (varType.isArray())
224                     {
225                         textureEnvType->makeArrays(varType.getArraySizes());
226                     }
227 
228                     fields.push_back(
229                         new TField(textureEnvType, var->name(), kNoSourceLoc, var->symbolType()));
230                 }
231             }
232             break;
233 
234             case Pipeline::Type::UniformBuffer:
235             {
236                 for (const TVariable *var : mPipelineVariableList)
237                 {
238                     auto &type  = CloneType(var->getType());
239                     auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
240                     mSymbolEnv.markAsPointer(*field, AddressSpace::Constant);
241                     mSymbolEnv.markAsUBO(*field);
242                     mSymbolEnv.markAsPointer(*var, AddressSpace::Constant);
243                     fields.push_back(field);
244                 }
245             }
246             break;
247             default:
248             {
249                 for (const TVariable *var : mPipelineVariableList)
250                 {
251                     auto &type  = CloneType(var->getType());
252                     auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
253                     fields.push_back(field);
254 
255                     if (mInvariants.contains(*var))
256                     {
257                         mInvariants.insert(*field);
258                     }
259                 }
260             }
261             break;
262         }
263 
264         Name pipelineStructName = mPipeline.getStructTypeName(Pipeline::Variant::Original);
265         auto &s = *new TStructure(&mSymbolTable, pipelineStructName.rawName(), &fields,
266                                   pipelineStructName.symbolType());
267 
268         outDeclSeq.push_back(new TIntermDeclaration{&CreateStructTypeVariable(mSymbolTable, s)});
269 
270         return s;
271     }
272 };
273 
274 ////////////////////////////////////////////////////////////////////////////////
275 
CreatePipelineMainLocalVar(TSymbolTable & symbolTable,const Pipeline & pipeline,PipelineScoped<TStructure> pipelineStruct)276 PipelineScoped<TVariable> CreatePipelineMainLocalVar(TSymbolTable &symbolTable,
277                                                      const Pipeline &pipeline,
278                                                      PipelineScoped<TStructure> pipelineStruct)
279 {
280     ASSERT(pipelineStruct.isTotallyFull());
281 
282     PipelineScoped<TVariable> pipelineMainLocalVar;
283 
284     auto populateExternalMainLocalVar = [&]() {
285         ASSERT(!pipelineMainLocalVar.external);
286         pipelineMainLocalVar.external = &CreateInstanceVariable(
287             symbolTable, *pipelineStruct.external,
288             pipeline.getStructInstanceName(pipelineStruct.isUniform()
289                                                ? Pipeline::Variant::Original
290                                                : Pipeline::Variant::Modified));
291     };
292 
293     auto populateDistinctInternalMainLocalVar = [&]() {
294         ASSERT(!pipelineMainLocalVar.internal);
295         pipelineMainLocalVar.internal =
296             &CreateInstanceVariable(symbolTable, *pipelineStruct.internal,
297                                     pipeline.getStructInstanceName(Pipeline::Variant::Original));
298     };
299 
300     if (pipeline.type == Pipeline::Type::InstanceId)
301     {
302         populateDistinctInternalMainLocalVar();
303     }
304     else if (pipeline.alwaysRequiresLocalVariableDeclarationInMain())
305     {
306         populateExternalMainLocalVar();
307 
308         if (pipelineStruct.isUniform())
309         {
310             pipelineMainLocalVar.internal = pipelineMainLocalVar.external;
311         }
312         else
313         {
314             populateDistinctInternalMainLocalVar();
315         }
316     }
317     else if (!pipelineStruct.isUniform())
318     {
319         populateDistinctInternalMainLocalVar();
320     }
321 
322     return pipelineMainLocalVar;
323 }
324 
325 class PipelineFunctionEnv
326 {
327   private:
328     TCompiler &mCompiler;
329     SymbolEnv &mSymbolEnv;
330     TSymbolTable &mSymbolTable;
331     IdGen &mIdGen;
332     const Pipeline &mPipeline;
333     const std::unordered_set<const TFunction *> &mPipelineFunctions;
334     const PipelineScoped<TStructure> mPipelineStruct;
335     PipelineScoped<TVariable> &mPipelineMainLocalVar;
336 
337     std::unordered_map<const TFunction *, const TFunction *> mFuncMap;
338 
339   public:
PipelineFunctionEnv(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar)340     PipelineFunctionEnv(TCompiler &compiler,
341                         SymbolEnv &symbolEnv,
342                         IdGen &idGen,
343                         const Pipeline &pipeline,
344                         const std::unordered_set<const TFunction *> &pipelineFunctions,
345                         PipelineScoped<TStructure> pipelineStruct,
346                         PipelineScoped<TVariable> &pipelineMainLocalVar)
347         : mCompiler(compiler),
348           mSymbolEnv(symbolEnv),
349           mSymbolTable(symbolEnv.symbolTable()),
350           mIdGen(idGen),
351           mPipeline(pipeline),
352           mPipelineFunctions(pipelineFunctions),
353           mPipelineStruct(pipelineStruct),
354           mPipelineMainLocalVar(pipelineMainLocalVar)
355     {}
356 
isOriginalPipelineFunction(const TFunction & func) const357     bool isOriginalPipelineFunction(const TFunction &func) const
358     {
359         return mPipelineFunctions.find(&func) != mPipelineFunctions.end();
360     }
361 
isUpdatedPipelineFunction(const TFunction & func) const362     bool isUpdatedPipelineFunction(const TFunction &func) const
363     {
364         auto it = mFuncMap.find(&func);
365         if (it == mFuncMap.end())
366         {
367             return false;
368         }
369         return &func == it->second;
370     }
371 
getUpdatedFunction(const TFunction & func)372     const TFunction &getUpdatedFunction(const TFunction &func)
373     {
374         ASSERT(isOriginalPipelineFunction(func) || isUpdatedPipelineFunction(func));
375 
376         const TFunction *newFunc;
377 
378         auto it = mFuncMap.find(&func);
379         if (it == mFuncMap.end())
380         {
381             const bool isMain = func.isMain();
382 
383             if (isMain && mPipeline.isPipelineOut())
384             {
385                 ASSERT(func.getReturnType().getBasicType() == TBasicType::EbtVoid);
386                 newFunc = &CloneFunctionAndChangeReturnType(mSymbolTable, nullptr, func,
387                                                             *mPipelineStruct.external);
388             }
389             else if (isMain && (mPipeline.type == Pipeline::Type::InvocationVertexGlobals ||
390                                 mPipeline.type == Pipeline::Type::InvocationFragmentGlobals))
391             {
392                 std::vector<const TVariable *> variables;
393                 for (const TField *field : mPipelineStruct.external->fields())
394                 {
395                     variables.push_back(new TVariable(&mSymbolTable, field->name(), field->type(),
396                                                       field->symbolType()));
397                 }
398                 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
399             }
400             else if (isMain && mPipeline.type == Pipeline::Type::Texture)
401             {
402                 std::vector<const TVariable *> variables;
403                 TranslatorMetalReflection *reflection =
404                     ((sh::TranslatorMetalDirect *)&mCompiler)->getTranslatorMetalReflection();
405                 for (const TField *field : mPipelineStruct.external->fields())
406                 {
407                     const TStructure *textureEnv = field->type()->getStruct();
408                     ASSERT(textureEnv && textureEnv->fields().size() == 2);
409                     for (const TField *subfield : textureEnv->fields())
410                     {
411                         const Name name = mIdGen.createNewName({field->name(), subfield->name()});
412                         TType &type     = *new TType(*subfield->type());
413                         ASSERT(!type.isArray());
414                         type.makeArrays(field->type()->getArraySizes());
415                         auto *var =
416                             new TVariable(&mSymbolTable, name.rawName(), &type, name.symbolType());
417                         variables.push_back(var);
418                         reflection->addOriginalName(var->uniqueId().get(), field->name().data());
419                     }
420                 }
421                 newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
422             }
423             else if (isMain && mPipeline.type == Pipeline::Type::InstanceId)
424             {
425                 Name name = mPipeline.getStructInstanceName(Pipeline::Variant::Modified);
426                 auto *var = new TVariable(&mSymbolTable, name.rawName(),
427                                           new TType(TBasicType::EbtUInt), name.symbolType());
428                 newFunc   = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, func, *var);
429                 mPipelineMainLocalVar.external = var;
430             }
431             else if (isMain && mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
432             {
433                 ASSERT(mPipelineMainLocalVar.isTotallyFull());
434                 newFunc = &func;
435             }
436             else
437             {
438                 const TVariable *var;
439                 AddressSpace addressSpace;
440 
441                 if (isMain && !mPipelineMainLocalVar.isUniform())
442                 {
443                     var = &CreateInstanceVariable(
444                         mSymbolTable, *mPipelineStruct.external,
445                         mPipeline.getStructInstanceName(Pipeline::Variant::Modified));
446                     addressSpace = mPipeline.externalAddressSpace();
447                 }
448                 else
449                 {
450                     if (mPipeline.type == Pipeline::Type::UniformBuffer)
451                     {
452                         TranslatorMetalReflection *reflection =
453                             ((sh::TranslatorMetalDirect *)&mCompiler)
454                                 ->getTranslatorMetalReflection();
455                         // TODO: need more checks to make sure they line up? Could be reordered?
456                         ASSERT(mPipelineStruct.external->fields().size() ==
457                                mPipelineStruct.internal->fields().size());
458                         for (size_t i = 0; i < mPipelineStruct.external->fields().size(); i++)
459                         {
460                             const TField *externalField = mPipelineStruct.external->fields()[i];
461                             const TField *internalField = mPipelineStruct.internal->fields()[i];
462                             const TType &externalType   = *externalField->type();
463                             const TType &internalType   = *internalField->type();
464                             ASSERT(externalType.getBasicType() == internalType.getBasicType());
465                             if (externalType.getBasicType() == TBasicType::EbtStruct)
466                             {
467                                 const TStructure *externalEnv = externalType.getStruct();
468                                 const TStructure *internalEnv = internalType.getStruct();
469                                 const std::string internalName =
470                                     reflection->getOriginalName(internalEnv->uniqueId().get());
471                                 reflection->addOriginalName(externalEnv->uniqueId().get(),
472                                                             internalName);
473                             }
474                         }
475                     }
476                     var = &CreateInstanceVariable(
477                         mSymbolTable, *mPipelineStruct.internal,
478                         mPipeline.getStructInstanceName(Pipeline::Variant::Original));
479                     addressSpace = mPipelineMainLocalVar.isUniform()
480                                        ? mPipeline.externalAddressSpace()
481                                        : AddressSpace::Thread;
482                 }
483 
484                 bool markAsReference = true;
485                 if (isMain)
486                 {
487                     switch (mPipeline.type)
488                     {
489                         case Pipeline::Type::VertexIn:
490                         case Pipeline::Type::FragmentIn:
491                             markAsReference = false;
492                             break;
493 
494                         default:
495                             break;
496                     }
497                 }
498 
499                 if (markAsReference)
500                 {
501                     mSymbolEnv.markAsReference(*var, addressSpace);
502                 }
503 
504                 newFunc = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, func, *var);
505             }
506 
507             mFuncMap[&func]   = newFunc;
508             mFuncMap[newFunc] = newFunc;
509         }
510         else
511         {
512             newFunc = it->second;
513         }
514 
515         return *newFunc;
516     }
517 
createUpdatedFunctionPrototype(TIntermFunctionPrototype & funcProtoNode)518     TIntermFunctionPrototype *createUpdatedFunctionPrototype(
519         TIntermFunctionPrototype &funcProtoNode)
520     {
521         const TFunction &func = *funcProtoNode.getFunction();
522         if (!isOriginalPipelineFunction(func) && !isUpdatedPipelineFunction(func))
523         {
524             return nullptr;
525         }
526         const TFunction &newFunc = getUpdatedFunction(func);
527         return new TIntermFunctionPrototype(&newFunc);
528     }
529 };
530 
531 class UpdatePipelineFunctions : private TIntermRebuild
532 {
533   private:
534     const Pipeline &mPipeline;
535     const PipelineScoped<TStructure> mPipelineStruct;
536     PipelineScoped<TVariable> &mPipelineMainLocalVar;
537     SymbolEnv &mSymbolEnv;
538     PipelineFunctionEnv mEnv;
539     const TFunction *mFuncOriginalToModified;
540     const TFunction *mFuncModifiedToOriginal;
541 
542   public:
ThreadPipeline(TCompiler & compiler,TIntermBlock & root,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)543     static bool ThreadPipeline(TCompiler &compiler,
544                                TIntermBlock &root,
545                                const Pipeline &pipeline,
546                                const std::unordered_set<const TFunction *> &pipelineFunctions,
547                                PipelineScoped<TStructure> pipelineStruct,
548                                PipelineScoped<TVariable> &pipelineMainLocalVar,
549                                IdGen &idGen,
550                                SymbolEnv &symbolEnv,
551                                const TFunction *funcOriginalToModified,
552                                const TFunction *funcModifiedToOriginal)
553     {
554         UpdatePipelineFunctions self(compiler, pipeline, pipelineFunctions, pipelineStruct,
555                                      pipelineMainLocalVar, idGen, symbolEnv, funcOriginalToModified,
556                                      funcModifiedToOriginal);
557         if (!self.rebuildRoot(root))
558         {
559             return false;
560         }
561         return true;
562     }
563 
564   private:
// Stores the pipeline description, struct pair and conversion functions, and sets up the
// function environment used to look up updated functions. `pipelineStruct` must have both of
// its members set. The function environment shares this object's reference to the main-local
// variable pair.
UpdatePipelineFunctions(TCompiler &compiler,
                        const Pipeline &pipeline,
                        const std::unordered_set<const TFunction *> &pipelineFunctions,
                        PipelineScoped<TStructure> pipelineStruct,
                        PipelineScoped<TVariable> &pipelineMainLocalVar,
                        IdGen &idGen,
                        SymbolEnv &symbolEnv,
                        const TFunction *funcOriginalToModified,
                        const TFunction *funcModifiedToOriginal)
    : TIntermRebuild(compiler, false, true),
      mPipeline(pipeline),
      mPipelineStruct(pipelineStruct),
      mPipelineMainLocalVar(pipelineMainLocalVar),
      mSymbolEnv(symbolEnv),
      mEnv(compiler, symbolEnv, idGen, pipeline, pipelineFunctions, pipelineStruct,
           mPipelineMainLocalVar),
      mFuncOriginalToModified(funcOriginalToModified),
      mFuncModifiedToOriginal(funcModifiedToOriginal)
{
    ASSERT(mPipelineStruct.isTotallyFull());
}
591 
getInternalPipelineVariable(const TFunction & pipelineFunc)592     const TVariable &getInternalPipelineVariable(const TFunction &pipelineFunc)
593     {
594         if (pipelineFunc.isMain() && (mPipeline.alwaysRequiresLocalVariableDeclarationInMain() ||
595                                       !mPipelineMainLocalVar.isUniform()))
596         {
597             ASSERT(mPipelineMainLocalVar.internal);
598             return *mPipelineMainLocalVar.internal;
599         }
600         else
601         {
602             ASSERT(pipelineFunc.getParamCount() > 0);
603             return *pipelineFunc.getParam(0);
604         }
605     }
606 
getExternalPipelineVariable(const TFunction & mainFunc)607     const TVariable &getExternalPipelineVariable(const TFunction &mainFunc)
608     {
609         ASSERT(mainFunc.isMain());
610         if (mPipelineMainLocalVar.external)
611         {
612             return *mPipelineMainLocalVar.external;
613         }
614         else
615         {
616             ASSERT(mainFunc.getParamCount() > 0);
617             return *mainFunc.getParam(0);
618         }
619     }
620 
visitAggregatePost(TIntermAggregate & callNode)621     PostResult visitAggregatePost(TIntermAggregate &callNode) override
622     {
623         if (callNode.isConstructor())
624         {
625             return callNode;
626         }
627         else
628         {
629             const TFunction &oldCalledFunc = *callNode.getFunction();
630             if (!mEnv.isOriginalPipelineFunction(oldCalledFunc))
631             {
632                 return callNode;
633             }
634             const TFunction &newCalledFunc = mEnv.getUpdatedFunction(oldCalledFunc);
635 
636             const TFunction *oldOwnerFunc = getParentFunction();
637             ASSERT(oldOwnerFunc);
638             const TFunction &newOwnerFunc = mEnv.getUpdatedFunction(*oldOwnerFunc);
639 
640             return *TIntermAggregate::CreateFunctionCall(
641                 newCalledFunc, &CloneSequenceAndPrepend(
642                                    *callNode.getSequence(),
643                                    *new TIntermSymbol(&getInternalPipelineVariable(newOwnerFunc))));
644         }
645     }
646 
visitFunctionPrototypePost(TIntermFunctionPrototype & funcProtoNode)647     PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &funcProtoNode) override
648     {
649         TIntermFunctionPrototype *newFuncProtoNode =
650             mEnv.createUpdatedFunctionPrototype(funcProtoNode);
651         if (newFuncProtoNode == nullptr)
652         {
653             return funcProtoNode;
654         }
655         return *newFuncProtoNode;
656     }
657 
visitFunctionDefinitionPost(TIntermFunctionDefinition & funcDefNode)658     PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &funcDefNode) override
659     {
660         if (funcDefNode.getFunction()->isMain())
661         {
662             return visitMain(funcDefNode);
663         }
664         else
665         {
666             return visitNonMain(funcDefNode);
667         }
668     }
669 
visitNonMain(TIntermFunctionDefinition & funcDefNode)670     TIntermNode &visitNonMain(TIntermFunctionDefinition &funcDefNode)
671     {
672         TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
673         ASSERT(!funcProtoNode.getFunction()->isMain());
674 
675         TIntermFunctionPrototype *newFuncProtoNode =
676             mEnv.createUpdatedFunctionPrototype(funcProtoNode);
677         if (newFuncProtoNode == nullptr)
678         {
679             return funcDefNode;
680         }
681 
682         const TFunction &func = *newFuncProtoNode->getFunction();
683         ASSERT(!func.isMain());
684 
685         TIntermBlock *body = funcDefNode.getBody();
686 
687         return *new TIntermFunctionDefinition(newFuncProtoNode, body);
688     }
689 
    // Rebuilds the definition of main() around the prototype produced by
    // mEnv.createUpdatedFunctionPrototype(). If no updated prototype is
    // produced (null result) the original definition is returned unchanged.
    //
    // Otherwise the original body is (depending on the pipeline) nested inside
    // a new block that:
    //   1. declares the pipeline's main-local variable(s),
    //   2. populates them (per pipeline type, see the branches below),
    //   3. runs the original body,
    //   4. for pipeline-out pipelines, converts the internal value to the
    //      external one and, when a local declaration is always required,
    //      returns the external variable.
    // For uniform pipelines that do not always require a local declaration the
    // body is left as-is and only the prototype is swapped.
    TIntermNode &visitMain(TIntermFunctionDefinition &funcDefNode)
    {
        TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
        ASSERT(funcProtoNode.getFunction()->isMain());

        TIntermFunctionPrototype *newFuncProtoNode =
            mEnv.createUpdatedFunctionPrototype(funcProtoNode);
        if (newFuncProtoNode == nullptr)
        {
            // No replacement prototype was produced: nothing to rewrite.
            return funcDefNode;
        }

        const TFunction &func = *newFuncProtoNode->getFunction();
        ASSERT(func.isMain());

        // Appends `mFuncModifiedToOriginal(external, internal)` to `body`.
        // Only emitted when the pipeline is NOT a pipeline-out.
        auto callModifiedToOriginal = [&](TIntermBlock &body) {
            ASSERT(mPipelineMainLocalVar.internal);
            if (!mPipeline.isPipelineOut())
            {
                ASSERT(mFuncModifiedToOriginal);
                auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
                auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
                body.appendStatement(TIntermAggregate::CreateFunctionCall(
                    *mFuncModifiedToOriginal, new TIntermSequence{m, o}));
            }
        };

        // Appends `mFuncOriginalToModified(internal, external)` to `body`.
        // Only emitted when the pipeline IS a pipeline-out.
        auto callOriginalToModified = [&](TIntermBlock &body) {
            ASSERT(mPipelineMainLocalVar.internal);
            if (mPipeline.isPipelineOut())
            {
                ASSERT(mFuncOriginalToModified);
                auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
                auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
                body.appendStatement(TIntermAggregate::CreateFunctionCall(
                    *mFuncOriginalToModified, new TIntermSequence{o, m}));
            }
        };

        TIntermBlock *body = funcDefNode.getBody();

        if (mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
        {
            ASSERT(mPipelineMainLocalVar.isTotallyFull());

            // The new body starts with the declaration of the internal local.
            auto *newBody = new TIntermBlock();
            newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});

            if (mPipeline.type == Pipeline::Type::InvocationVertexGlobals ||
                mPipeline.type == Pipeline::Type::InvocationFragmentGlobals)
            {
                // Populate struct instance with references to global pipeline variables.
                // For every field of the external struct, a variable with the
                // field's name/type/symbol type is created and assigned to the
                // same-named field of the internal local.
                for (const TField *field : mPipelineStruct.external->fields())
                {
                    auto *var        = new TVariable(&mSymbolTable, field->name(), field->type(),
                                              field->symbolType());
                    auto *symbol     = new TIntermSymbol(var);
                    auto &accessNode = AccessField(*mPipelineMainLocalVar.internal, var->name());
                    auto *assignNode = new TIntermBinary(TOperator::EOpAssign, &accessNode, symbol);
                    newBody->appendStatement(assignNode);
                }
            }
            else if (mPipeline.type == Pipeline::Type::Texture)
            {
                const TFieldList &fields = mPipelineStruct.external->fields();

                // The last 2 * fields.size() parameters of main are consumed as
                // (texture, sampler) pairs, one pair per struct field, in field order.
                ASSERT(func.getParamCount() >= 2 * fields.size());
                size_t paramIndex = func.getParamCount() - 2 * fields.size();

                for (const TField *field : fields)
                {
                    const TVariable &textureParam = *func.getParam(paramIndex++);
                    const TVariable &samplerParam = *func.getParam(paramIndex++);

                    // Emits the two assignments for one element of `env`:
                    //   env[index].texture = addressof(textureParam[index]);
                    //   env[index].sampler = addressof(samplerParam[index]);
                    // `index` is null for non-array fields (presumably AccessIndex
                    // then leaves the node un-indexed -- TODO confirm in AstHelpers).
                    // Note: the `env` parameter shadows the outer `env` below.
                    auto go = [&](TIntermTyped &env, const int *index) {
                        TIntermTyped &textureField = AccessField(
                            AccessIndex(*env.deepCopy(), index), ImmutableString("texture"));
                        TIntermTyped &samplerField = AccessField(
                            AccessIndex(*env.deepCopy(), index), ImmutableString("sampler"));

                        // Builds `field = addressof(param[index])`, resolving the
                        // "addressof" overload against the type of `field`.
                        auto mkAssign = [&](TIntermTyped &field, const TVariable &param) {
                            return new TIntermBinary(TOperator::EOpAssign, &field,
                                                     &mSymbolEnv.callFunctionOverload(
                                                         Name("addressof"), field.getType(),
                                                         *new TIntermSequence{&AccessIndex(
                                                             *new TIntermSymbol(&param), index)}));
                        };

                        newBody->appendStatement(mkAssign(textureField, textureParam));
                        newBody->appendStatement(mkAssign(samplerField, samplerParam));
                    };

                    TIntermTyped &env = AccessField(*mPipelineMainLocalVar.internal, field->name());
                    const TType &envType = env.getType();

                    if (envType.isArray())
                    {
                        // Only one-dimensional arrays are handled; emit the
                        // assignments once per array element.
                        ASSERT(!envType.isArrayOfArrays());
                        const auto n = static_cast<int>(envType.getArraySizeProduct());
                        for (int i = 0; i < n; ++i)
                        {
                            go(env, &i);
                        }
                    }
                    else
                    {
                        go(env, nullptr);
                    }
                }
            }
            else if (mPipeline.type == Pipeline::Type::InstanceId)
            {
                // Assign field 0 of the internal pipeline variable from the
                // external pipeline variable, converted via AsType to int.
                newBody->appendStatement(new TIntermBinary(
                    TOperator::EOpAssign,
                    &AccessFieldByIndex(*new TIntermSymbol(&getInternalPipelineVariable(func)), 0),
                    &AsType(mSymbolEnv, *new TType(TBasicType::EbtInt),
                            *new TIntermSymbol(&getExternalPipelineVariable(func)))));
            }
            else if (!mPipelineMainLocalVar.isUniform())
            {
                // Remaining non-uniform pipelines: also declare the external
                // local, then convert external -> internal when applicable.
                newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.external});
                callModifiedToOriginal(*newBody);
            }

            // The original body runs as a nested block after the setup statements.
            newBody->appendStatement(body);

            if (!mPipelineMainLocalVar.isUniform())
            {
                // Convert internal -> external after the body (pipeline-out only).
                callOriginalToModified(*newBody);
            }

            if (mPipeline.isPipelineOut())
            {
                // Pipeline-out: main returns the external variable.
                newBody->appendStatement(new TIntermBranch(
                    TOperator::EOpReturn, new TIntermSymbol(mPipelineMainLocalVar.external)));
            }

            body = newBody;
        }
        else if (!mPipelineMainLocalVar.isUniform())
        {
            // No unconditional local declaration is required and the variable
            // is not a uniform: only the internal local exists. Declare it,
            // convert in, run the original body, convert out.
            ASSERT(!mPipelineMainLocalVar.external);
            ASSERT(mPipelineMainLocalVar.internal);

            auto *newBody = new TIntermBlock();
            newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});
            callModifiedToOriginal(*newBody);
            newBody->appendStatement(body);
            callOriginalToModified(*newBody);
            body = newBody;
        }

        // In every case the definition is rebuilt with the updated prototype.
        return *new TIntermFunctionDefinition(newFuncProtoNode, body);
    }
844 };
845 
846 ////////////////////////////////////////////////////////////////////////////////
847 
UpdatePipelineSymbols(Pipeline::Type pipelineType,TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv,const VariableSet & pipelineVariables,PipelineScoped<TVariable> pipelineMainLocalVar)848 bool UpdatePipelineSymbols(Pipeline::Type pipelineType,
849                            TCompiler &compiler,
850                            TIntermBlock &root,
851                            SymbolEnv &symbolEnv,
852                            const VariableSet &pipelineVariables,
853                            PipelineScoped<TVariable> pipelineMainLocalVar)
854 {
855     auto map = [&](const TFunction *owner, TIntermSymbol &symbol) -> TIntermNode & {
856         const TVariable &var = symbol.variable();
857         if (pipelineVariables.find(&var) == pipelineVariables.end())
858         {
859             return symbol;
860         }
861         ASSERT(owner);
862         const TVariable *structInstanceVar;
863         if (owner->isMain())
864         {
865             ASSERT(pipelineMainLocalVar.internal);
866             structInstanceVar = pipelineMainLocalVar.internal;
867         }
868         else
869         {
870             ASSERT(owner->getParamCount() > 0);
871             structInstanceVar = owner->getParam(0);
872         }
873         ASSERT(structInstanceVar);
874         return AccessField(*structInstanceVar, var.name());
875     };
876     return MapSymbols(compiler, root, map);
877 }
878 
879 ////////////////////////////////////////////////////////////////////////////////
880 
RewritePipeline(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,Invariants & invariants,PipelineScoped<TStructure> & outStruct)881 bool RewritePipeline(TCompiler &compiler,
882                      TIntermBlock &root,
883                      IdGen &idGen,
884                      const Pipeline &pipeline,
885                      SymbolEnv &symbolEnv,
886                      Invariants &invariants,
887                      PipelineScoped<TStructure> &outStruct)
888 {
889     ASSERT(outStruct.isTotallyEmpty());
890 
891     TSymbolTable &symbolTable = compiler.getSymbolTable();
892 
893     PipelineStructInfo psi;
894     if (!GeneratePipelineStruct::Exec(psi, compiler, root, idGen, pipeline, symbolEnv, invariants))
895     {
896         return false;
897     }
898 
899     if (psi.isEmpty())
900     {
901         return true;
902     }
903 
904     const auto pipelineFunctions = DiscoverDependentFunctions(root, [&](const TVariable &var) {
905         return psi.pipelineVariables.find(&var) != psi.pipelineVariables.end();
906     });
907 
908     auto pipelineMainLocalVar =
909         CreatePipelineMainLocalVar(symbolTable, pipeline, psi.pipelineStruct);
910 
911     if (!UpdatePipelineFunctions::ThreadPipeline(
912             compiler, root, pipeline, pipelineFunctions, psi.pipelineStruct, pipelineMainLocalVar,
913             idGen, symbolEnv, psi.funcOriginalToModified, psi.funcModifiedToOriginal))
914     {
915         return false;
916     }
917 
918     if (!pipeline.globalInstanceVar)
919     {
920         if (!UpdatePipelineSymbols(pipeline.type, compiler, root, symbolEnv, psi.pipelineVariables,
921                                    pipelineMainLocalVar))
922         {
923             return false;
924         }
925     }
926 
927     if (!PruneNoOps(&compiler, &root, &compiler.getSymbolTable()))
928     {
929         return false;
930     }
931 
932     outStruct = psi.pipelineStruct;
933     return true;
934 }
935 
936 }  // anonymous namespace
937 
RewritePipelines(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,DriverUniform & angleUniformsGlobalInstanceVar,SymbolEnv & symbolEnv,Invariants & invariants,PipelineStructs & outStructs)938 bool sh::RewritePipelines(TCompiler &compiler,
939                           TIntermBlock &root,
940                           IdGen &idGen,
941                           DriverUniform &angleUniformsGlobalInstanceVar,
942                           SymbolEnv &symbolEnv,
943                           Invariants &invariants,
944                           PipelineStructs &outStructs)
945 {
946     struct Info
947     {
948         Pipeline::Type pipelineType;
949         PipelineScoped<TStructure> &outStruct;
950         const TVariable *globalInstanceVar;
951     };
952 
953     Info infos[] = {
954         {Pipeline::Type::InstanceId, outStructs.instanceId, nullptr},
955         {Pipeline::Type::Texture, outStructs.texture, nullptr},
956         {Pipeline::Type::NonConstantGlobals, outStructs.nonConstantGlobals, nullptr},
957         {Pipeline::Type::AngleUniforms, outStructs.angleUniforms,
958          angleUniformsGlobalInstanceVar.getDriverUniformsVariable()},
959         {Pipeline::Type::UserUniforms, outStructs.userUniforms, nullptr},
960         {Pipeline::Type::VertexIn, outStructs.vertexIn, nullptr},
961         {Pipeline::Type::VertexOut, outStructs.vertexOut, nullptr},
962         {Pipeline::Type::FragmentIn, outStructs.fragmentIn, nullptr},
963         {Pipeline::Type::FragmentOut, outStructs.fragmentOut, nullptr},
964         {Pipeline::Type::InvocationVertexGlobals, outStructs.invocationVertexGlobals, nullptr},
965         {Pipeline::Type::InvocationFragmentGlobals, outStructs.invocationFragmentGlobals, nullptr},
966         {Pipeline::Type::UniformBuffer, outStructs.uniformBuffers, nullptr},
967     };
968 
969     for (Info &info : infos)
970     {
971         Pipeline pipeline{info.pipelineType, info.globalInstanceVar};
972         if (!RewritePipeline(compiler, root, idGen, pipeline, symbolEnv, invariants,
973                              info.outStruct))
974         {
975             return false;
976         }
977     }
978 
979     return true;
980 }
981