1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/TranslatorMetalDirect/WrapMain.h"
8 #include "compiler/translator/Compiler.h"
9 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
10 
11 using namespace sh;
12 
13 ////////////////////////////////////////////////////////////////////////////////
14 
15 namespace
16 {
17 
18 class Wrapper : public TIntermTraverser
19 {
20   private:
21     IdGen &mIdGen;
22 
23   public:
Wrapper(TSymbolTable & symbolTable,IdGen & idGen)24     Wrapper(TSymbolTable &symbolTable, IdGen &idGen)
25         : TIntermTraverser(false, false, true, &symbolTable), mIdGen(idGen)
26     {}
27 
visitBlock(Visit,TIntermBlock * blockNode)28     bool visitBlock(Visit, TIntermBlock *blockNode) override
29     {
30         if (blockNode != getRootNode())
31         {
32             return true;
33         }
34 
35         for (TIntermNode *node : *blockNode->getSequence())
36         {
37             if (TIntermFunctionDefinition *funcDefNode = node->getAsFunctionDefinition())
38             {
39                 const TFunction &func = *funcDefNode->getFunction();
40                 if (func.isMain())
41                 {
42                     visitMain(*blockNode, funcDefNode);
43                     break;
44                 }
45             }
46         }
47 
48         return true;
49     }
50 
51   private:
visitMain(TIntermBlock & root,TIntermFunctionDefinition * funcDefNode)52     void visitMain(TIntermBlock &root, TIntermFunctionDefinition *funcDefNode)
53     {
54         const TFunction &func = *funcDefNode->getFunction();
55         ASSERT(func.isMain());
56         ASSERT(func.getReturnType().getBasicType() == TBasicType::EbtVoid);
57         ASSERT(func.getParamCount() == 0);
58 
59         const TFunction &externalMainFunc = *funcDefNode->getFunction();
60         const TFunction &internalMainFunc = CloneFunction(*mSymbolTable, mIdGen, externalMainFunc);
61 
62         TIntermFunctionPrototype *externalMainProto = funcDefNode->getFunctionPrototype();
63         TIntermFunctionPrototype *internalMainProto =
64             new TIntermFunctionPrototype(&internalMainFunc);
65 
66         TIntermBlock *externalMainBody = new TIntermBlock();
67         externalMainBody->appendStatement(
68             TIntermAggregate::CreateFunctionCall(internalMainFunc, new TIntermSequence()));
69 
70         TIntermBlock *internalMainBody = funcDefNode->getBody();
71 
72         TIntermFunctionDefinition *externalMainDef =
73             new TIntermFunctionDefinition(externalMainProto, externalMainBody);
74         TIntermFunctionDefinition *internalMainDef =
75             new TIntermFunctionDefinition(internalMainProto, internalMainBody);
76 
77         mMultiReplacements.push_back(NodeReplaceWithMultipleEntry(
78             &root, funcDefNode, TIntermSequence{internalMainDef, externalMainDef}));
79     }
80 };
81 
82 }  // namespace
83 
WrapMain(TCompiler & compiler,IdGen & idGen,TIntermBlock & root)84 bool sh::WrapMain(TCompiler &compiler, IdGen &idGen, TIntermBlock &root)
85 {
86     TSymbolTable &symbolTable = compiler.getSymbolTable();
87     Wrapper wrapper(symbolTable, idGen);
88     root.traverse(&wrapper);
89     if (!wrapper.updateTree(&compiler, &root))
90     {
91         return false;
92     }
93     return true;
94 }
95