1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <algorithm>
8 #include <unordered_map>
9
10 #include "compiler/translator/TranslatorMetalDirect.h"
11 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
12 #include "compiler/translator/TranslatorMetalDirect/IntermRebuild.h"
13 #include "compiler/translator/TranslatorMetalDirect/ReduceInterfaceBlocks.h"
14 #include "compiler/translator/tree_ops/SeparateDeclarations.h"
15
16 using namespace sh;
17
18 ////////////////////////////////////////////////////////////////////////////////
19
20 namespace
21 {
22
23 class Reducer : public TIntermRebuild
24 {
25 std::unordered_map<const TInterfaceBlock *, const TVariable *> mLiftedMap;
26 std::unordered_map<const TVariable *, const TVariable *> mInstanceMap;
27 IdGen &mIdGen;
28
29 public:
Reducer(TCompiler & compiler,IdGen & idGen)30 Reducer(TCompiler &compiler, IdGen &idGen)
31 : TIntermRebuild(compiler, true, false), mIdGen(idGen)
32 {}
33
visitDeclarationPre(TIntermDeclaration & declNode)34 PreResult visitDeclarationPre(TIntermDeclaration &declNode) override
35 {
36 ASSERT(declNode.getChildCount() == 1);
37 TIntermNode &node = *declNode.getChildNode(0);
38
39 if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
40 {
41 const TVariable &var = symbolNode->variable();
42 const TType &type = var.getType();
43 const SymbolType symbolType = var.symbolType();
44 if (const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock())
45 {
46 if (symbolType == SymbolType::Empty)
47 {
48 // Create instance variable
49 auto &structure =
50 *new TStructure(&mSymbolTable, interfaceBlock->name(),
51 &interfaceBlock->fields(), interfaceBlock->symbolType());
52 auto &structVar = CreateStructTypeVariable(mSymbolTable, structure);
53
54 auto &instanceVar = CreateInstanceVariable(
55 mSymbolTable, structure, mIdGen.createNewName(interfaceBlock->name()),
56 TQualifier::EvqBuffer, &type.getArraySizes());
57 mLiftedMap[interfaceBlock] = &instanceVar;
58
59 TIntermNode *replacements[] = {
60 new TIntermDeclaration{new TIntermSymbol(&structVar)},
61 new TIntermDeclaration{new TIntermSymbol(&instanceVar)}};
62 return PreResult::Multi(std::begin(replacements), std::end(replacements));
63 }
64 else
65 {
66 ASSERT(type.getQualifier() == TQualifier::EvqUniform);
67
68 auto &structure =
69 *new TStructure(&mSymbolTable, interfaceBlock->name(),
70 &interfaceBlock->fields(), interfaceBlock->symbolType());
71 auto &structVar = CreateStructTypeVariable(mSymbolTable, structure);
72 auto &instanceVar =
73 CreateInstanceVariable(mSymbolTable, structure, Name(var),
74 TQualifier::EvqBuffer, &type.getArraySizes());
75
76 mInstanceMap[&var] = &instanceVar;
77
78 TIntermNode *replacements[] = {
79 new TIntermDeclaration{new TIntermSymbol(&structVar)},
80 new TIntermDeclaration{new TIntermSymbol(&instanceVar)}};
81 return PreResult::Multi(std::begin(replacements), std::end(replacements));
82 }
83 }
84 }
85
86 return {declNode, VisitBits::Both};
87 }
88
visitSymbolPre(TIntermSymbol & symbolNode)89 PreResult visitSymbolPre(TIntermSymbol &symbolNode) override
90 {
91 const TVariable &var = symbolNode.variable();
92 {
93 auto it = mInstanceMap.find(&var);
94 if (it != mInstanceMap.end())
95 {
96 return *new TIntermSymbol(it->second);
97 }
98 }
99 if (const TInterfaceBlock *ib = var.getType().getInterfaceBlock())
100 {
101 auto it = mLiftedMap.find(ib);
102 if (it != mLiftedMap.end())
103 {
104 return AccessField(*(it->second), var.name());
105 }
106 }
107 return symbolNode;
108 }
109 };
110
111 } // anonymous namespace
112
113 ////////////////////////////////////////////////////////////////////////////////
114
ReduceInterfaceBlocks(TCompiler & compiler,TIntermBlock & root,IdGen & idGen)115 bool sh::ReduceInterfaceBlocks(TCompiler &compiler, TIntermBlock &root, IdGen &idGen)
116 {
117 Reducer reducer(compiler, idGen);
118 if (!reducer.rebuildRoot(root))
119 {
120 return false;
121 }
122
123 if (!SeparateDeclarations(&compiler, &root))
124 {
125 return false;
126 }
127
128 return true;
129 }
130