1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <algorithm>
8 #include <unordered_map>
9 
10 #include "compiler/translator/TranslatorMetalDirect.h"
11 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
12 #include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
13 #include "compiler/translator/TranslatorMetalDirect/ReduceInterfaceBlocks.h"
14 #include "compiler/translator/tree_ops/SeparateDeclarations.h"
15 
16 using namespace sh;
17 
18 ////////////////////////////////////////////////////////////////////////////////
19 
20 namespace
21 {
22 
23 class Reducer : public TIntermRebuild
24 {
25     std::unordered_map<const TInterfaceBlock *, const TVariable *> mLiftedMap;
26     std::unordered_map<const TVariable *, const TVariable *> mInstanceMap;
27     IdGen &mIdGen;
28 
29   public:
Reducer(TCompiler & compiler,IdGen & idGen)30     Reducer(TCompiler &compiler, IdGen &idGen)
31         : TIntermRebuild(compiler, true, false), mIdGen(idGen)
32     {}
33 
visitDeclarationPre(TIntermDeclaration & declNode)34     PreResult visitDeclarationPre(TIntermDeclaration &declNode) override
35     {
36         ASSERT(declNode.getChildCount() == 1);
37         TIntermNode &node = *declNode.getChildNode(0);
38 
39         if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
40         {
41             const TVariable &var        = symbolNode->variable();
42             const TType &type           = var.getType();
43             const SymbolType symbolType = var.symbolType();
44             if (const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock())
45             {
46                 if (symbolType == SymbolType::Empty)
47                 {
48                     // Create instance variable
49                     auto &structure =
50                         *new TStructure(&mSymbolTable, interfaceBlock->name(),
51                                         &interfaceBlock->fields(), interfaceBlock->symbolType());
52                     auto &structVar = CreateStructTypeVariable(mSymbolTable, structure);
53 
54                     auto &instanceVar = CreateInstanceVariable(
55                         mSymbolTable, structure, mIdGen.createNewName(interfaceBlock->name()),
56                         TQualifier::EvqBuffer, &type.getArraySizes());
57                     mLiftedMap[interfaceBlock] = &instanceVar;
58 
59                     TIntermNode *replacements[] = {
60                         new TIntermDeclaration{new TIntermSymbol(&structVar)},
61                         new TIntermDeclaration{new TIntermSymbol(&instanceVar)}};
62                     return PreResult::Multi(std::begin(replacements), std::end(replacements));
63                 }
64                 else
65                 {
66                     ASSERT(type.getQualifier() == TQualifier::EvqUniform);
67 
68                     auto &structure =
69                         *new TStructure(&mSymbolTable, interfaceBlock->name(),
70                                         &interfaceBlock->fields(), interfaceBlock->symbolType());
71                     auto &structVar = CreateStructTypeVariable(mSymbolTable, structure);
72                     auto &instanceVar =
73                         CreateInstanceVariable(mSymbolTable, structure, Name(var),
74                                                TQualifier::EvqBuffer, &type.getArraySizes());
75 
76                     mInstanceMap[&var] = &instanceVar;
77 
78                     TIntermNode *replacements[] = {
79                         new TIntermDeclaration{new TIntermSymbol(&structVar)},
80                         new TIntermDeclaration{new TIntermSymbol(&instanceVar)}};
81                     return PreResult::Multi(std::begin(replacements), std::end(replacements));
82                 }
83             }
84         }
85 
86         return {declNode, VisitBits::Both};
87     }
88 
visitSymbolPre(TIntermSymbol & symbolNode)89     PreResult visitSymbolPre(TIntermSymbol &symbolNode) override
90     {
91         const TVariable &var = symbolNode.variable();
92         {
93             auto it = mInstanceMap.find(&var);
94             if (it != mInstanceMap.end())
95             {
96                 return *new TIntermSymbol(it->second);
97             }
98         }
99         if (const TInterfaceBlock *ib = var.getType().getInterfaceBlock())
100         {
101             auto it = mLiftedMap.find(ib);
102             if (it != mLiftedMap.end())
103             {
104                 return AccessField(*(it->second), var.name());
105             }
106         }
107         return symbolNode;
108     }
109 };
110 
111 }  // anonymous namespace
112 
113 ////////////////////////////////////////////////////////////////////////////////
114 
ReduceInterfaceBlocks(TCompiler & compiler,TIntermBlock & root,IdGen & idGen)115 bool sh::ReduceInterfaceBlocks(TCompiler &compiler, TIntermBlock &root, IdGen &idGen)
116 {
117     Reducer reducer(compiler, idGen);
118     if (!reducer.rebuildRoot(root))
119     {
120         return false;
121     }
122 
123     if (!SeparateDeclarations(&compiler, &root))
124     {
125         return false;
126     }
127 
128     return true;
129 }
130