1 // Copyright 2016 The SwiftShader Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //    http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "AnalyzeCallDepth.h"
16 
17 static TIntermSequence::iterator
traverseCaseBody(AnalyzeCallDepth * analysis,TIntermSequence::iterator & start,const TIntermSequence::iterator & end)18 traverseCaseBody(AnalyzeCallDepth* analysis,
19 				 TIntermSequence::iterator& start,
20 				 const TIntermSequence::iterator& end) {
21 	TIntermSequence::iterator current = start;
22 	for (++current; current != end; ++current)
23 	{
24 		(*current)->traverse(analysis);
25 		if((*current)->getAsBranchNode()) // Kill, Break, Continue or Return
26 		{
27 			break;
28 		}
29 	}
30 	return current;
31 }
32 
33 
FunctionNode(TIntermAggregate * node)34 AnalyzeCallDepth::FunctionNode::FunctionNode(TIntermAggregate *node) : node(node)
35 {
36 	visit = PreVisit;
37 	callDepth = 0;
38 }
39 
getName() const40 const TString &AnalyzeCallDepth::FunctionNode::getName() const
41 {
42 	return node->getName();
43 }
44 
addCallee(AnalyzeCallDepth::FunctionNode * callee)45 void AnalyzeCallDepth::FunctionNode::addCallee(AnalyzeCallDepth::FunctionNode *callee)
46 {
47 	for(size_t i = 0; i < callees.size(); i++)
48 	{
49 		if(callees[i] == callee)
50 		{
51 			return;
52 		}
53 	}
54 
55 	callees.push_back(callee);
56 }
57 
analyzeCallDepth(AnalyzeCallDepth * analyzeCallDepth)58 unsigned int AnalyzeCallDepth::FunctionNode::analyzeCallDepth(AnalyzeCallDepth *analyzeCallDepth)
59 {
60 	ASSERT(visit == PreVisit);
61 	ASSERT(analyzeCallDepth);
62 
63 	callDepth = 0;
64 	visit = InVisit;
65 
66 	for(size_t i = 0; i < callees.size(); i++)
67 	{
68 		unsigned int calleeDepth = 0;
69 		switch(callees[i]->visit)
70 		{
71 		case InVisit:
72 			// Cycle detected (recursion)
73 			return UINT_MAX;
74 		case PostVisit:
75 			calleeDepth = callees[i]->getLastDepth();
76 			break;
77 		case PreVisit:
78 			calleeDepth = callees[i]->analyzeCallDepth(analyzeCallDepth);
79 			break;
80 		default:
81 			UNREACHABLE(callees[i]->visit);
82 			break;
83 		}
84 		if(calleeDepth != UINT_MAX) ++calleeDepth;
85 		callDepth = std::max(callDepth, calleeDepth);
86 	}
87 
88 	visit = PostVisit;
89 	return callDepth;
90 }
91 
getLastDepth() const92 unsigned int AnalyzeCallDepth::FunctionNode::getLastDepth() const
93 {
94 	return callDepth;
95 }
96 
removeIfUnreachable()97 void AnalyzeCallDepth::FunctionNode::removeIfUnreachable()
98 {
99 	if(visit == PreVisit)
100 	{
101 		node->setOp(EOpPrototype);
102 		node->getSequence().resize(1);   // Remove function body
103 	}
104 }
105 
AnalyzeCallDepth(TIntermNode * root)106 AnalyzeCallDepth::AnalyzeCallDepth(TIntermNode *root)
107 	: TIntermTraverser(true, false, true, false),
108 	  currentFunction(0)
109 {
110 	root->traverse(this);
111 }
112 
~AnalyzeCallDepth()113 AnalyzeCallDepth::~AnalyzeCallDepth()
114 {
115 	for(size_t i = 0; i < functions.size(); i++)
116 	{
117 		delete functions[i];
118 	}
119 }
120 
visitSwitch(Visit visit,TIntermSwitch * node)121 bool AnalyzeCallDepth::visitSwitch(Visit visit, TIntermSwitch *node)
122 {
123 	TIntermTyped* switchValue = node->getInit();
124 	TIntermAggregate* opList = node->getStatementList();
125 
126 	if(!switchValue || !opList)
127 	{
128 		return false;
129 	}
130 
131 	// TODO: We need to dig into switch statement cases from
132 	// visitSwitch for all traversers. Is there a way to
133 	// preserve existing functionality while moving the iteration
134 	// to the general traverser?
135 	TIntermSequence& sequence = opList->getSequence();
136 	TIntermSequence::iterator it = sequence.begin();
137 	TIntermSequence::iterator defaultIt = sequence.end();
138 	for(; it != sequence.end(); ++it)
139 	{
140 		TIntermCase* currentCase = (*it)->getAsCaseNode();
141 		if(currentCase)
142 		{
143 			TIntermSequence::iterator caseIt = it;
144 			TIntermTyped* condition = currentCase->getCondition();
145 			if(condition) // non default case
146 			{
147 				condition->traverse(this);
148 				traverseCaseBody(this, caseIt, sequence.end());
149 			}
150 			else
151 			{
152 				defaultIt = it; // The default case might not be the last case, keep it for last
153 			}
154 		}
155 	}
156 
157 	// If there's a default case, traverse it here
158 	if(defaultIt != sequence.end())
159 	{
160 		traverseCaseBody(this, defaultIt, sequence.end());
161 	}
162 	return false;
163 }
164 
visitAggregate(Visit visit,TIntermAggregate * node)165 bool AnalyzeCallDepth::visitAggregate(Visit visit, TIntermAggregate *node)
166 {
167 	switch(node->getOp())
168 	{
169 	case EOpFunction:   // Function definition
170 		{
171 			if(visit == PreVisit)
172 			{
173 				currentFunction = findFunctionByName(node->getName());
174 
175 				if(!currentFunction)
176 				{
177 					currentFunction = new FunctionNode(node);
178 					functions.push_back(currentFunction);
179 				}
180 			}
181 			else if(visit == PostVisit)
182 			{
183 				currentFunction = 0;
184 			}
185 		}
186 		break;
187 	case EOpFunctionCall:
188 		{
189 			if(!node->isUserDefined())
190 			{
191 				return true;   // Check the arguments for function calls
192 			}
193 
194 			if(visit == PreVisit)
195 			{
196 				FunctionNode *function = findFunctionByName(node->getName());
197 
198 				if(!function)
199 				{
200 					function = new FunctionNode(node);
201 					functions.push_back(function);
202 				}
203 
204 				if(currentFunction)
205 				{
206 					currentFunction->addCallee(function);
207 				}
208 				else
209 				{
210 					globalFunctionCalls.insert(function);
211 				}
212 			}
213 		}
214 		break;
215 	default:
216 		break;
217 	}
218 
219 	return true;
220 }
221 
analyzeCallDepth()222 unsigned int AnalyzeCallDepth::analyzeCallDepth()
223 {
224 	FunctionNode *main = findFunctionByName("main(");
225 
226 	if(!main)
227 	{
228 		return 0;
229 	}
230 
231 	unsigned int depth = main->analyzeCallDepth(this);
232 	if(depth != UINT_MAX) ++depth;
233 
234 	for(FunctionSet::iterator globalCall = globalFunctionCalls.begin(); globalCall != globalFunctionCalls.end(); globalCall++)
235 	{
236 		unsigned int globalDepth = (*globalCall)->analyzeCallDepth(this);
237 		if(globalDepth != UINT_MAX) ++globalDepth;
238 
239 		if(globalDepth > depth)
240 		{
241 			depth = globalDepth;
242 		}
243 	}
244 
245 	for(size_t i = 0; i < functions.size(); i++)
246 	{
247 		functions[i]->removeIfUnreachable();
248 	}
249 
250 	return depth;
251 }
252 
findFunctionByName(const TString & name)253 AnalyzeCallDepth::FunctionNode *AnalyzeCallDepth::findFunctionByName(const TString &name)
254 {
255 	for(size_t i = 0; i < functions.size(); i++)
256 	{
257 		if(functions[i]->getName() == name)
258 		{
259 			return functions[i];
260 		}
261 	}
262 
263 	return 0;
264 }
265 
266