1 // Copyright 2016 The SwiftShader Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "AnalyzeCallDepth.h"
16
FunctionNode(TIntermAggregate * node)17 AnalyzeCallDepth::FunctionNode::FunctionNode(TIntermAggregate *node) : node(node)
18 {
19 visit = PreVisit;
20 callDepth = 0;
21 }
22
getName() const23 const TString &AnalyzeCallDepth::FunctionNode::getName() const
24 {
25 return node->getName();
26 }
27
addCallee(AnalyzeCallDepth::FunctionNode * callee)28 void AnalyzeCallDepth::FunctionNode::addCallee(AnalyzeCallDepth::FunctionNode *callee)
29 {
30 for(size_t i = 0; i < callees.size(); i++)
31 {
32 if(callees[i] == callee)
33 {
34 return;
35 }
36 }
37
38 callees.push_back(callee);
39 }
40
analyzeCallDepth(AnalyzeCallDepth * analyzeCallDepth)41 unsigned int AnalyzeCallDepth::FunctionNode::analyzeCallDepth(AnalyzeCallDepth *analyzeCallDepth)
42 {
43 ASSERT(visit == PreVisit);
44 ASSERT(analyzeCallDepth);
45
46 callDepth = 0;
47 visit = InVisit;
48
49 for(size_t i = 0; i < callees.size(); i++)
50 {
51 unsigned int calleeDepth = 0;
52 switch(callees[i]->visit)
53 {
54 case InVisit:
55 // Cycle detected (recursion)
56 return UINT_MAX;
57 case PostVisit:
58 calleeDepth = callees[i]->getLastDepth();
59 break;
60 case PreVisit:
61 calleeDepth = callees[i]->analyzeCallDepth(analyzeCallDepth);
62 break;
63 default:
64 UNREACHABLE(callees[i]->visit);
65 break;
66 }
67 if(calleeDepth != UINT_MAX) ++calleeDepth;
68 callDepth = std::max(callDepth, calleeDepth);
69 }
70
71 visit = PostVisit;
72 return callDepth;
73 }
74
getLastDepth() const75 unsigned int AnalyzeCallDepth::FunctionNode::getLastDepth() const
76 {
77 return callDepth;
78 }
79
removeIfUnreachable()80 void AnalyzeCallDepth::FunctionNode::removeIfUnreachable()
81 {
82 if(visit == PreVisit)
83 {
84 node->setOp(EOpPrototype);
85 node->getSequence().resize(1); // Remove function body
86 }
87 }
88
AnalyzeCallDepth(TIntermNode * root)89 AnalyzeCallDepth::AnalyzeCallDepth(TIntermNode *root)
90 : TIntermTraverser(true, false, true, false),
91 currentFunction(0)
92 {
93 root->traverse(this);
94 }
95
~AnalyzeCallDepth()96 AnalyzeCallDepth::~AnalyzeCallDepth()
97 {
98 for(size_t i = 0; i < functions.size(); i++)
99 {
100 delete functions[i];
101 }
102 }
103
visitAggregate(Visit visit,TIntermAggregate * node)104 bool AnalyzeCallDepth::visitAggregate(Visit visit, TIntermAggregate *node)
105 {
106 switch(node->getOp())
107 {
108 case EOpFunction: // Function definition
109 {
110 if(visit == PreVisit)
111 {
112 currentFunction = findFunctionByName(node->getName());
113
114 if(!currentFunction)
115 {
116 currentFunction = new FunctionNode(node);
117 functions.push_back(currentFunction);
118 }
119 }
120 else if(visit == PostVisit)
121 {
122 currentFunction = 0;
123 }
124 }
125 break;
126 case EOpFunctionCall:
127 {
128 if(!node->isUserDefined())
129 {
130 return true; // Check the arguments for function calls
131 }
132
133 if(visit == PreVisit)
134 {
135 FunctionNode *function = findFunctionByName(node->getName());
136
137 if(!function)
138 {
139 function = new FunctionNode(node);
140 functions.push_back(function);
141 }
142
143 if(currentFunction)
144 {
145 currentFunction->addCallee(function);
146 }
147 else
148 {
149 globalFunctionCalls.insert(function);
150 }
151 }
152 }
153 break;
154 default:
155 break;
156 }
157
158 return true;
159 }
160
analyzeCallDepth()161 unsigned int AnalyzeCallDepth::analyzeCallDepth()
162 {
163 FunctionNode *main = findFunctionByName("main(");
164
165 if(!main)
166 {
167 return 0;
168 }
169
170 unsigned int depth = main->analyzeCallDepth(this);
171 if(depth != UINT_MAX) ++depth;
172
173 for(FunctionSet::iterator globalCall = globalFunctionCalls.begin(); globalCall != globalFunctionCalls.end(); globalCall++)
174 {
175 unsigned int globalDepth = (*globalCall)->analyzeCallDepth(this);
176 if(globalDepth != UINT_MAX) ++globalDepth;
177
178 if(globalDepth > depth)
179 {
180 depth = globalDepth;
181 }
182 }
183
184 for(size_t i = 0; i < functions.size(); i++)
185 {
186 functions[i]->removeIfUnreachable();
187 }
188
189 return depth;
190 }
191
findFunctionByName(const TString & name)192 AnalyzeCallDepth::FunctionNode *AnalyzeCallDepth::findFunctionByName(const TString &name)
193 {
194 for(size_t i = 0; i < functions.size(); i++)
195 {
196 if(functions[i]->getName() == name)
197 {
198 return functions[i];
199 }
200 }
201
202 return 0;
203 }
204
205