// // Copyright 2016 The ANGLE Project Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // // ConstantFoldingTest.h: // Utilities for constant folding tests. // #ifndef TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_ #define TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_ #include #include "common/mathutil.h" #include "compiler/translator/tree_util/FindMain.h" #include "compiler/translator/tree_util/FindSymbolNode.h" #include "compiler/translator/tree_util/IntermTraverse.h" #include "tests/test_utils/ShaderCompileTreeTest.h" namespace sh { class TranslatorESSL; template class ConstantFinder : public TIntermTraverser { public: ConstantFinder(const std::vector &constantVector) : TIntermTraverser(true, false, false), mConstantVector(constantVector), mFaultTolerance(T()), mFound(false) {} ConstantFinder(const std::vector &constantVector, const T &faultTolerance) : TIntermTraverser(true, false, false), mConstantVector(constantVector), mFaultTolerance(faultTolerance), mFound(false) {} ConstantFinder(const T &value) : TIntermTraverser(true, false, false), mFaultTolerance(T()), mFound(false) { mConstantVector.push_back(value); } void visitConstantUnion(TIntermConstantUnion *node) { if (node->getType().getObjectSize() == mConstantVector.size()) { bool found = true; for (size_t i = 0; i < mConstantVector.size(); i++) { if (!isEqual(node->getConstantValue()[i], mConstantVector[i])) { found = false; break; } } if (found) { mFound = found; } } } bool found() const { return mFound; } private: bool isEqual(const TConstantUnion &node, const float &value) const { if (node.getType() != EbtFloat) { return false; } if (value == std::numeric_limits::infinity()) { return gl::isInf(node.getFConst()) && node.getFConst() > 0; } else if (value == -std::numeric_limits::infinity()) { return gl::isInf(node.getFConst()) && node.getFConst() < 0; } else if (gl::isNaN(value)) { // All NaNs are treated as equal. return gl::isNaN(node.getFConst()); } return mFaultTolerance >= fabsf(node.getFConst() - value); } bool isEqual(const TConstantUnion &node, const int &value) const { if (node.getType() != EbtInt) { return false; } ASSERT(mFaultTolerance < std::numeric_limits::max()); // abs() returns 0 at least on some platforms when the minimum int value is passed in (it // doesn't have a positive counterpart). return mFaultTolerance >= abs(node.getIConst() - value) && (node.getIConst() - value) != std::numeric_limits::min(); } bool isEqual(const TConstantUnion &node, const unsigned int &value) const { if (node.getType() != EbtUInt) { return false; } ASSERT(mFaultTolerance < static_cast(std::numeric_limits::max())); return static_cast(mFaultTolerance) >= abs(static_cast(node.getUConst() - value)) && static_cast(node.getUConst() - value) != std::numeric_limits::min(); } bool isEqual(const TConstantUnion &node, const bool &value) const { if (node.getType() != EbtBool) { return false; } return node.getBConst() == value; } std::vector mConstantVector; T mFaultTolerance; bool mFound; }; class ConstantFoldingTest : public ShaderCompileTreeTest { public: ConstantFoldingTest() {} protected: ::GLenum getShaderType() const override { return GL_FRAGMENT_SHADER; } ShShaderSpec getShaderSpec() const override { return SH_GLES3_1_SPEC; } template bool constantFoundInAST(T constant) { ConstantFinder finder(constant); mASTRoot->traverse(&finder); return finder.found(); } template bool constantVectorFoundInAST(const std::vector &constantVector) { ConstantFinder finder(constantVector); mASTRoot->traverse(&finder); return finder.found(); } template bool constantColumnMajorMatrixFoundInAST(const std::vector &constantMatrix) { return constantVectorFoundInAST(constantMatrix); } template bool constantVectorNearFoundInAST(const std::vector &constantVector, const T &faultTolerance) { ConstantFinder finder(constantVector, faultTolerance); mASTRoot->traverse(&finder); return finder.found(); } bool symbolFoundInAST(const char *symbolName) { return FindSymbolNode(mASTRoot, ImmutableString(symbolName)) != nullptr; } bool symbolFoundInMain(const char *symbolName) { return FindSymbolNode(FindMain(mASTRoot), ImmutableString(symbolName)) != nullptr; } }; class ConstantFoldingExpressionTest : public ConstantFoldingTest { public: ConstantFoldingExpressionTest() {} void evaluateFloat(const std::string &floatExpression); void evaluateInt(const std::string &intExpression); void evaluateUint(const std::string &uintExpression); }; } // namespace sh #endif // TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_