1 // Copyright 2016 The SwiftShader Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //    http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "ValidateLimitations.h"
16 #include "InfoSink.h"
17 #include "InitializeParseContext.h"
18 #include "ParseHelper.h"
19 
20 namespace {
IsLoopIndex(const TIntermSymbol * symbol,const TLoopStack & stack)21 bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) {
22 	for (TLoopStack::const_iterator i = stack.begin(); i != stack.end(); ++i) {
23 		if (i->index.id == symbol->getId())
24 			return true;
25 	}
26 	return false;
27 }
28 
MarkLoopForUnroll(const TIntermSymbol * symbol,TLoopStack & stack)29 void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) {
30 	for (TLoopStack::iterator i = stack.begin(); i != stack.end(); ++i) {
31 		if (i->index.id == symbol->getId()) {
32 			ASSERT(i->loop);
33 			i->loop->setUnrollFlag(true);
34 			return;
35 		}
36 	}
37 	UNREACHABLE(0);
38 }
39 
40 // Traverses a node to check if it represents a constant index expression.
41 // Definition:
42 // constant-index-expressions are a superset of constant-expressions.
43 // Constant-index-expressions can include loop indices as defined in
44 // GLSL ES 1.0 spec, Appendix A, section 4.
45 // The following are constant-index-expressions:
46 // - Constant expressions
47 // - Loop indices as defined in section 4
48 // - Expressions composed of both of the above
49 class ValidateConstIndexExpr : public TIntermTraverser {
50 public:
ValidateConstIndexExpr(const TLoopStack & stack)51 	ValidateConstIndexExpr(const TLoopStack& stack)
52 		: mValid(true), mLoopStack(stack) {}
53 
54 	// Returns true if the parsed node represents a constant index expression.
isValid() const55 	bool isValid() const { return mValid; }
56 
visitSymbol(TIntermSymbol * symbol)57 	virtual void visitSymbol(TIntermSymbol* symbol) {
58 		// Only constants and loop indices are allowed in a
59 		// constant index expression.
60 		if (mValid) {
61 			mValid = (symbol->getQualifier() == EvqConstExpr) ||
62 			         IsLoopIndex(symbol, mLoopStack);
63 		}
64 	}
65 
66 private:
67 	bool mValid;
68 	const TLoopStack& mLoopStack;
69 };
70 
71 // Traverses a node to check if it uses a loop index.
72 // If an int loop index is used in its body as a sampler array index,
73 // mark the loop for unroll.
74 class ValidateLoopIndexExpr : public TIntermTraverser {
75 public:
ValidateLoopIndexExpr(TLoopStack & stack)76 	ValidateLoopIndexExpr(TLoopStack& stack)
77 		: mUsesFloatLoopIndex(false),
78 		  mUsesIntLoopIndex(false),
79 		  mLoopStack(stack) {}
80 
usesFloatLoopIndex() const81 	bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; }
usesIntLoopIndex() const82 	bool usesIntLoopIndex() const { return mUsesIntLoopIndex; }
83 
visitSymbol(TIntermSymbol * symbol)84 	virtual void visitSymbol(TIntermSymbol* symbol) {
85 		if (IsLoopIndex(symbol, mLoopStack)) {
86 			switch (symbol->getBasicType()) {
87 			case EbtFloat:
88 				mUsesFloatLoopIndex = true;
89 				break;
90 			case EbtUInt:
91 				mUsesIntLoopIndex = true;
92 				MarkLoopForUnroll(symbol, mLoopStack);
93 				break;
94 			case EbtInt:
95 				mUsesIntLoopIndex = true;
96 				MarkLoopForUnroll(symbol, mLoopStack);
97 				break;
98 			default:
99 				UNREACHABLE(symbol->getBasicType());
100 			}
101 		}
102 	}
103 
104 private:
105 	bool mUsesFloatLoopIndex;
106 	bool mUsesIntLoopIndex;
107 	TLoopStack& mLoopStack;
108 };
109 }  // namespace
110 
ValidateLimitations(GLenum shaderType,TInfoSinkBase & sink)111 ValidateLimitations::ValidateLimitations(GLenum shaderType,
112                                          TInfoSinkBase& sink)
113 	: mShaderType(shaderType),
114 	  mSink(sink),
115 	  mNumErrors(0)
116 {
117 }
118 
visitBinary(Visit,TIntermBinary * node)119 bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node)
120 {
121 	// Check if loop index is modified in the loop body.
122 	validateOperation(node, node->getLeft());
123 
124 	// Check indexing.
125 	switch (node->getOp()) {
126 	case EOpIndexDirect:
127 		validateIndexing(node);
128 		break;
129 	case EOpIndexIndirect:
130 		validateIndexing(node);
131 		break;
132 	default: break;
133 	}
134 	return true;
135 }
136 
visitUnary(Visit,TIntermUnary * node)137 bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node)
138 {
139 	// Check if loop index is modified in the loop body.
140 	validateOperation(node, node->getOperand());
141 
142 	return true;
143 }
144 
visitAggregate(Visit,TIntermAggregate * node)145 bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node)
146 {
147 	switch (node->getOp()) {
148 	case EOpFunctionCall:
149 		validateFunctionCall(node);
150 		break;
151 	default:
152 		break;
153 	}
154 	return true;
155 }
156 
visitLoop(Visit,TIntermLoop * node)157 bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node)
158 {
159 	if (!validateLoopType(node))
160 		return false;
161 
162 	TLoopInfo info;
163 	memset(&info, 0, sizeof(TLoopInfo));
164 	info.loop = node;
165 	if (!validateForLoopHeader(node, &info))
166 		return false;
167 
168 	TIntermNode* body = node->getBody();
169 	if (body) {
170 		mLoopStack.push_back(info);
171 		body->traverse(this);
172 		mLoopStack.pop_back();
173 	}
174 
175 	// The loop is fully processed - no need to visit children.
176 	return false;
177 }
178 
error(TSourceLoc loc,const char * reason,const char * token)179 void ValidateLimitations::error(TSourceLoc loc,
180                                 const char *reason, const char* token)
181 {
182 	mSink.prefix(EPrefixError);
183 	mSink.location(loc);
184 	mSink << "'" << token << "' : " << reason << "\n";
185 	++mNumErrors;
186 }
187 
withinLoopBody() const188 bool ValidateLimitations::withinLoopBody() const
189 {
190 	return !mLoopStack.empty();
191 }
192 
isLoopIndex(const TIntermSymbol * symbol) const193 bool ValidateLimitations::isLoopIndex(const TIntermSymbol* symbol) const
194 {
195 	return IsLoopIndex(symbol, mLoopStack);
196 }
197 
validateLoopType(TIntermLoop * node)198 bool ValidateLimitations::validateLoopType(TIntermLoop* node) {
199 	TLoopType type = node->getType();
200 	if (type == ELoopFor)
201 		return true;
202 
203 	// Reject while and do-while loops.
204 	error(node->getLine(),
205 		  "This type of loop is not allowed",
206 		  type == ELoopWhile ? "while" : "do");
207 	return false;
208 }
209 
validateForLoopHeader(TIntermLoop * node,TLoopInfo * info)210 bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node,
211                                                 TLoopInfo* info)
212 {
213 	ASSERT(node->getType() == ELoopFor);
214 
215 	//
216 	// The for statement has the form:
217 	//    for ( init-declaration ; condition ; expression ) statement
218 	//
219 	if (!validateForLoopInit(node, info))
220 		return false;
221 	if (!validateForLoopCond(node, info))
222 		return false;
223 	if (!validateForLoopExpr(node, info))
224 		return false;
225 
226 	return true;
227 }
228 
validateForLoopInit(TIntermLoop * node,TLoopInfo * info)229 bool ValidateLimitations::validateForLoopInit(TIntermLoop* node,
230                                               TLoopInfo* info)
231 {
232 	TIntermNode* init = node->getInit();
233 	if (!init) {
234 		error(node->getLine(), "Missing init declaration", "for");
235 		return false;
236 	}
237 
238 	//
239 	// init-declaration has the form:
240 	//     type-specifier identifier = constant-expression
241 	//
242 	TIntermAggregate* decl = init->getAsAggregate();
243 	if (!decl || (decl->getOp() != EOpDeclaration)) {
244 		error(init->getLine(), "Invalid init declaration", "for");
245 		return false;
246 	}
247 	// To keep things simple do not allow declaration list.
248 	TIntermSequence& declSeq = decl->getSequence();
249 	if (declSeq.size() != 1) {
250 		error(decl->getLine(), "Invalid init declaration", "for");
251 		return false;
252 	}
253 	TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
254 	if (!declInit || (declInit->getOp() != EOpInitialize)) {
255 		error(decl->getLine(), "Invalid init declaration", "for");
256 		return false;
257 	}
258 	TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
259 	if (!symbol) {
260 		error(declInit->getLine(), "Invalid init declaration", "for");
261 		return false;
262 	}
263 	// The loop index has type int or float.
264 	TBasicType type = symbol->getBasicType();
265 	if (!IsInteger(type) && (type != EbtFloat)) {
266 		error(symbol->getLine(),
267 			  "Invalid type for loop index", getBasicString(type));
268 		return false;
269 	}
270 	// The loop index is initialized with constant expression.
271 	if (!isConstExpr(declInit->getRight())) {
272 		error(declInit->getLine(),
273 			  "Loop index cannot be initialized with non-constant expression",
274 			  symbol->getSymbol().c_str());
275 		return false;
276 	}
277 
278 	info->index.id = symbol->getId();
279 	return true;
280 }
281 
validateForLoopCond(TIntermLoop * node,TLoopInfo * info)282 bool ValidateLimitations::validateForLoopCond(TIntermLoop* node,
283                                               TLoopInfo* info)
284 {
285 	TIntermNode* cond = node->getCondition();
286 	if (!cond) {
287 		error(node->getLine(), "Missing condition", "for");
288 		return false;
289 	}
290 	//
291 	// condition has the form:
292 	//     loop_index relational_operator constant_expression
293 	//
294 	TIntermBinary* binOp = cond->getAsBinaryNode();
295 	if (!binOp) {
296 		error(node->getLine(), "Invalid condition", "for");
297 		return false;
298 	}
299 	// Loop index should be to the left of relational operator.
300 	TIntermSymbol* symbol = binOp->getLeft()->getAsSymbolNode();
301 	if (!symbol) {
302 		error(binOp->getLine(), "Invalid condition", "for");
303 		return false;
304 	}
305 	if (symbol->getId() != info->index.id) {
306 		error(symbol->getLine(),
307 			  "Expected loop index", symbol->getSymbol().c_str());
308 		return false;
309 	}
310 	// Relational operator is one of: > >= < <= == or !=.
311 	switch (binOp->getOp()) {
312 	case EOpEqual:
313 	case EOpNotEqual:
314 	case EOpLessThan:
315 	case EOpGreaterThan:
316 	case EOpLessThanEqual:
317 	case EOpGreaterThanEqual:
318 		break;
319 	default:
320 		error(binOp->getLine(),
321 			  "Invalid relational operator",
322 			  getOperatorString(binOp->getOp()));
323 		break;
324 	}
325 	// Loop index must be compared with a constant.
326 	if (!isConstExpr(binOp->getRight())) {
327 		error(binOp->getLine(),
328 			  "Loop index cannot be compared with non-constant expression",
329 			  symbol->getSymbol().c_str());
330 		return false;
331 	}
332 
333 	return true;
334 }
335 
validateForLoopExpr(TIntermLoop * node,TLoopInfo * info)336 bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node,
337                                               TLoopInfo* info)
338 {
339 	TIntermNode* expr = node->getExpression();
340 	if (!expr) {
341 		error(node->getLine(), "Missing expression", "for");
342 		return false;
343 	}
344 
345 	// for expression has one of the following forms:
346 	//     loop_index++
347 	//     loop_index--
348 	//     loop_index += constant_expression
349 	//     loop_index -= constant_expression
350 	//     ++loop_index
351 	//     --loop_index
352 	// The last two forms are not specified in the spec, but I am assuming
353 	// its an oversight.
354 	TIntermUnary* unOp = expr->getAsUnaryNode();
355 	TIntermBinary* binOp = unOp ? nullptr : expr->getAsBinaryNode();
356 
357 	TOperator op = EOpNull;
358 	TIntermSymbol* symbol = nullptr;
359 	if (unOp) {
360 		op = unOp->getOp();
361 		symbol = unOp->getOperand()->getAsSymbolNode();
362 	} else if (binOp) {
363 		op = binOp->getOp();
364 		symbol = binOp->getLeft()->getAsSymbolNode();
365 	}
366 
367 	// The operand must be loop index.
368 	if (!symbol) {
369 		error(expr->getLine(), "Invalid expression", "for");
370 		return false;
371 	}
372 	if (symbol->getId() != info->index.id) {
373 		error(symbol->getLine(),
374 			  "Expected loop index", symbol->getSymbol().c_str());
375 		return false;
376 	}
377 
378 	// The operator is one of: ++ -- += -=.
379 	switch (op) {
380 		case EOpPostIncrement:
381 		case EOpPostDecrement:
382 		case EOpPreIncrement:
383 		case EOpPreDecrement:
384 			ASSERT((unOp != NULL) && (binOp == NULL));
385 			break;
386 		case EOpAddAssign:
387 		case EOpSubAssign:
388 			ASSERT((unOp == NULL) && (binOp != NULL));
389 			break;
390 		default:
391 			error(expr->getLine(), "Invalid operator", getOperatorString(op));
392 			return false;
393 	}
394 
395 	// Loop index must be incremented/decremented with a constant.
396 	if (binOp != NULL) {
397 		if (!isConstExpr(binOp->getRight())) {
398 			error(binOp->getLine(),
399 				  "Loop index cannot be modified by non-constant expression",
400 				  symbol->getSymbol().c_str());
401 			return false;
402 		}
403 	}
404 
405 	return true;
406 }
407 
validateFunctionCall(TIntermAggregate * node)408 bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node)
409 {
410 	ASSERT(node->getOp() == EOpFunctionCall);
411 
412 	// If not within loop body, there is nothing to check.
413 	if (!withinLoopBody())
414 		return true;
415 
416 	// List of param indices for which loop indices are used as argument.
417 	typedef std::vector<int> ParamIndex;
418 	ParamIndex pIndex;
419 	TIntermSequence& params = node->getSequence();
420 	for (TIntermSequence::size_type i = 0; i < params.size(); ++i) {
421 		TIntermSymbol* symbol = params[i]->getAsSymbolNode();
422 		if (symbol && isLoopIndex(symbol))
423 			pIndex.push_back(i);
424 	}
425 	// If none of the loop indices are used as arguments,
426 	// there is nothing to check.
427 	if (pIndex.empty())
428 		return true;
429 
430 	bool valid = true;
431 	TSymbolTable& symbolTable = GetGlobalParseContext()->symbolTable;
432 	TSymbol* symbol = symbolTable.find(node->getName(), GetGlobalParseContext()->getShaderVersion());
433 	ASSERT(symbol && symbol->isFunction());
434 	TFunction* function = static_cast<TFunction*>(symbol);
435 	for (ParamIndex::const_iterator i = pIndex.begin();
436 		 i != pIndex.end(); ++i) {
437 		const TParameter& param = function->getParam(*i);
438 		TQualifier qual = param.type->getQualifier();
439 		if ((qual == EvqOut) || (qual == EvqInOut)) {
440 			error(params[*i]->getLine(),
441 				  "Loop index cannot be used as argument to a function out or inout parameter",
442 				  params[*i]->getAsSymbolNode()->getSymbol().c_str());
443 			valid = false;
444 		}
445 	}
446 
447 	return valid;
448 }
449 
validateOperation(TIntermOperator * node,TIntermNode * operand)450 bool ValidateLimitations::validateOperation(TIntermOperator* node,
451                                             TIntermNode* operand) {
452 	// Check if loop index is modified in the loop body.
453 	if (!withinLoopBody() || !node->modifiesState())
454 		return true;
455 
456 	const TIntermSymbol* symbol = operand->getAsSymbolNode();
457 	if (symbol && isLoopIndex(symbol)) {
458 		error(node->getLine(),
459 			  "Loop index cannot be statically assigned to within the body of the loop",
460 			  symbol->getSymbol().c_str());
461 	}
462 	return true;
463 }
464 
isConstExpr(TIntermNode * node)465 bool ValidateLimitations::isConstExpr(TIntermNode* node)
466 {
467 	ASSERT(node);
468 	return node->getAsConstantUnion() != nullptr;
469 }
470 
isConstIndexExpr(TIntermNode * node)471 bool ValidateLimitations::isConstIndexExpr(TIntermNode* node)
472 {
473 	ASSERT(node);
474 
475 	ValidateConstIndexExpr validate(mLoopStack);
476 	node->traverse(&validate);
477 	return validate.isValid();
478 }
479 
validateIndexing(TIntermBinary * node)480 bool ValidateLimitations::validateIndexing(TIntermBinary* node)
481 {
482 	ASSERT((node->getOp() == EOpIndexDirect) ||
483 	       (node->getOp() == EOpIndexIndirect));
484 
485 	bool valid = true;
486 	TIntermTyped* index = node->getRight();
487 	// The index expression must have integral type.
488 	if (!index->isScalarInt()) {
489 		error(index->getLine(),
490 		      "Index expression must have integral type",
491 		      index->getCompleteString().c_str());
492 		valid = false;
493 	}
494 	// The index expession must be a constant-index-expression unless
495 	// the operand is a uniform in a vertex shader.
496 	TIntermTyped* operand = node->getLeft();
497 	bool skip = (mShaderType == GL_VERTEX_SHADER) &&
498 	            (operand->getQualifier() == EvqUniform);
499 	if (!skip && !isConstIndexExpr(index)) {
500 		error(index->getLine(), "Index expression must be constant", "[]");
501 		valid = false;
502 	}
503 	return valid;
504 }
505 
506