1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/TranslatorMetalDirect/WrapMain.h"
8 #include "compiler/translator/Compiler.h"
9 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
10
11 using namespace sh;
12
13 ////////////////////////////////////////////////////////////////////////////////
14
15 namespace
16 {
17
18 class Wrapper : public TIntermTraverser
19 {
20 private:
21 IdGen &mIdGen;
22
23 public:
Wrapper(TSymbolTable & symbolTable,IdGen & idGen)24 Wrapper(TSymbolTable &symbolTable, IdGen &idGen)
25 : TIntermTraverser(false, false, true, &symbolTable), mIdGen(idGen)
26 {}
27
visitBlock(Visit,TIntermBlock * blockNode)28 bool visitBlock(Visit, TIntermBlock *blockNode) override
29 {
30 if (blockNode != getRootNode())
31 {
32 return true;
33 }
34
35 for (TIntermNode *node : *blockNode->getSequence())
36 {
37 if (TIntermFunctionDefinition *funcDefNode = node->getAsFunctionDefinition())
38 {
39 const TFunction &func = *funcDefNode->getFunction();
40 if (func.isMain())
41 {
42 visitMain(*blockNode, funcDefNode);
43 break;
44 }
45 }
46 }
47
48 return true;
49 }
50
51 private:
visitMain(TIntermBlock & root,TIntermFunctionDefinition * funcDefNode)52 void visitMain(TIntermBlock &root, TIntermFunctionDefinition *funcDefNode)
53 {
54 const TFunction &func = *funcDefNode->getFunction();
55 ASSERT(func.isMain());
56 ASSERT(func.getReturnType().getBasicType() == TBasicType::EbtVoid);
57 ASSERT(func.getParamCount() == 0);
58
59 const TFunction &externalMainFunc = *funcDefNode->getFunction();
60 const TFunction &internalMainFunc = CloneFunction(*mSymbolTable, mIdGen, externalMainFunc);
61
62 TIntermFunctionPrototype *externalMainProto = funcDefNode->getFunctionPrototype();
63 TIntermFunctionPrototype *internalMainProto =
64 new TIntermFunctionPrototype(&internalMainFunc);
65
66 TIntermBlock *externalMainBody = new TIntermBlock();
67 externalMainBody->appendStatement(
68 TIntermAggregate::CreateFunctionCall(internalMainFunc, new TIntermSequence()));
69
70 TIntermBlock *internalMainBody = funcDefNode->getBody();
71
72 TIntermFunctionDefinition *externalMainDef =
73 new TIntermFunctionDefinition(externalMainProto, externalMainBody);
74 TIntermFunctionDefinition *internalMainDef =
75 new TIntermFunctionDefinition(internalMainProto, internalMainBody);
76
77 mMultiReplacements.push_back(NodeReplaceWithMultipleEntry(
78 &root, funcDefNode, TIntermSequence{internalMainDef, externalMainDef}));
79 }
80 };
81
82 } // namespace
83
WrapMain(TCompiler & compiler,IdGen & idGen,TIntermBlock & root)84 bool sh::WrapMain(TCompiler &compiler, IdGen &idGen, TIntermBlock &root)
85 {
86 TSymbolTable &symbolTable = compiler.getSymbolTable();
87 Wrapper wrapper(symbolTable, idGen);
88 root.traverse(&wrapper);
89 if (!wrapper.updateTree(&compiler, &root))
90 {
91 return false;
92 }
93 return true;
94 }
95