1 // Copyright 2016 The SwiftShader Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //    http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "AnalyzeCallDepth.h"
16 
FunctionNode(TIntermAggregate * node)17 AnalyzeCallDepth::FunctionNode::FunctionNode(TIntermAggregate *node) : node(node)
18 {
19 	visit = PreVisit;
20 	callDepth = 0;
21 }
22 
getName() const23 const TString &AnalyzeCallDepth::FunctionNode::getName() const
24 {
25 	return node->getName();
26 }
27 
addCallee(AnalyzeCallDepth::FunctionNode * callee)28 void AnalyzeCallDepth::FunctionNode::addCallee(AnalyzeCallDepth::FunctionNode *callee)
29 {
30 	for(size_t i = 0; i < callees.size(); i++)
31 	{
32 		if(callees[i] == callee)
33 		{
34 			return;
35 		}
36 	}
37 
38 	callees.push_back(callee);
39 }
40 
analyzeCallDepth(AnalyzeCallDepth * analyzeCallDepth)41 unsigned int AnalyzeCallDepth::FunctionNode::analyzeCallDepth(AnalyzeCallDepth *analyzeCallDepth)
42 {
43 	ASSERT(visit == PreVisit);
44 	ASSERT(analyzeCallDepth);
45 
46 	callDepth = 0;
47 	visit = InVisit;
48 
49 	for(size_t i = 0; i < callees.size(); i++)
50 	{
51 		unsigned int calleeDepth = 0;
52 		switch(callees[i]->visit)
53 		{
54 		case InVisit:
55 			// Cycle detected (recursion)
56 			return UINT_MAX;
57 		case PostVisit:
58 			calleeDepth = callees[i]->getLastDepth();
59 			break;
60 		case PreVisit:
61 			calleeDepth = callees[i]->analyzeCallDepth(analyzeCallDepth);
62 			break;
63 		default:
64 			UNREACHABLE(callees[i]->visit);
65 			break;
66 		}
67 		if(calleeDepth != UINT_MAX) ++calleeDepth;
68 		callDepth = std::max(callDepth, calleeDepth);
69 	}
70 
71 	visit = PostVisit;
72 	return callDepth;
73 }
74 
getLastDepth() const75 unsigned int AnalyzeCallDepth::FunctionNode::getLastDepth() const
76 {
77 	return callDepth;
78 }
79 
removeIfUnreachable()80 void AnalyzeCallDepth::FunctionNode::removeIfUnreachable()
81 {
82 	if(visit == PreVisit)
83 	{
84 		node->setOp(EOpPrototype);
85 		node->getSequence().resize(1);   // Remove function body
86 	}
87 }
88 
AnalyzeCallDepth(TIntermNode * root)89 AnalyzeCallDepth::AnalyzeCallDepth(TIntermNode *root)
90 	: TIntermTraverser(true, false, true, false),
91 	  currentFunction(0)
92 {
93 	root->traverse(this);
94 }
95 
~AnalyzeCallDepth()96 AnalyzeCallDepth::~AnalyzeCallDepth()
97 {
98 	for(size_t i = 0; i < functions.size(); i++)
99 	{
100 		delete functions[i];
101 	}
102 }
103 
visitAggregate(Visit visit,TIntermAggregate * node)104 bool AnalyzeCallDepth::visitAggregate(Visit visit, TIntermAggregate *node)
105 {
106 	switch(node->getOp())
107 	{
108 	case EOpFunction:   // Function definition
109 		{
110 			if(visit == PreVisit)
111 			{
112 				currentFunction = findFunctionByName(node->getName());
113 
114 				if(!currentFunction)
115 				{
116 					currentFunction = new FunctionNode(node);
117 					functions.push_back(currentFunction);
118 				}
119 			}
120 			else if(visit == PostVisit)
121 			{
122 				currentFunction = 0;
123 			}
124 		}
125 		break;
126 	case EOpFunctionCall:
127 		{
128 			if(!node->isUserDefined())
129 			{
130 				return true;   // Check the arguments for function calls
131 			}
132 
133 			if(visit == PreVisit)
134 			{
135 				FunctionNode *function = findFunctionByName(node->getName());
136 
137 				if(!function)
138 				{
139 					function = new FunctionNode(node);
140 					functions.push_back(function);
141 				}
142 
143 				if(currentFunction)
144 				{
145 					currentFunction->addCallee(function);
146 				}
147 				else
148 				{
149 					globalFunctionCalls.insert(function);
150 				}
151 			}
152 		}
153 		break;
154 	default:
155 		break;
156 	}
157 
158 	return true;
159 }
160 
analyzeCallDepth()161 unsigned int AnalyzeCallDepth::analyzeCallDepth()
162 {
163 	FunctionNode *main = findFunctionByName("main(");
164 
165 	if(!main)
166 	{
167 		return 0;
168 	}
169 
170 	unsigned int depth = main->analyzeCallDepth(this);
171 	if(depth != UINT_MAX) ++depth;
172 
173 	for(FunctionSet::iterator globalCall = globalFunctionCalls.begin(); globalCall != globalFunctionCalls.end(); globalCall++)
174 	{
175 		unsigned int globalDepth = (*globalCall)->analyzeCallDepth(this);
176 		if(globalDepth != UINT_MAX) ++globalDepth;
177 
178 		if(globalDepth > depth)
179 		{
180 			depth = globalDepth;
181 		}
182 	}
183 
184 	for(size_t i = 0; i < functions.size(); i++)
185 	{
186 		functions[i]->removeIfUnreachable();
187 	}
188 
189 	return depth;
190 }
191 
findFunctionByName(const TString & name)192 AnalyzeCallDepth::FunctionNode *AnalyzeCallDepth::findFunctionByName(const TString &name)
193 {
194 	for(size_t i = 0; i < functions.size(); i++)
195 	{
196 		if(functions[i]->getName() == name)
197 		{
198 			return functions[i];
199 		}
200 	}
201 
202 	return 0;
203 }
204 
205