1 //
2 // Copyright 2016 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // ConstantFoldingTest.h:
7 //   Utilities for constant folding tests.
8 //
9 
10 #ifndef TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_
11 #define TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_
12 
13 #include <vector>
14 
15 #include "common/mathutil.h"
16 #include "compiler/translator/tree_util/FindMain.h"
17 #include "compiler/translator/tree_util/FindSymbolNode.h"
18 #include "compiler/translator/tree_util/IntermTraverse.h"
19 #include "tests/test_utils/ShaderCompileTreeTest.h"
20 
21 namespace sh
22 {
23 
24 class TranslatorESSL;
25 
26 template <typename T>
27 class ConstantFinder : public TIntermTraverser
28 {
29   public:
ConstantFinder(const std::vector<T> & constantVector)30     ConstantFinder(const std::vector<T> &constantVector)
31         : TIntermTraverser(true, false, false),
32           mConstantVector(constantVector),
33           mFaultTolerance(T()),
34           mFound(false)
35     {}
36 
ConstantFinder(const std::vector<T> & constantVector,const T & faultTolerance)37     ConstantFinder(const std::vector<T> &constantVector, const T &faultTolerance)
38         : TIntermTraverser(true, false, false),
39           mConstantVector(constantVector),
40           mFaultTolerance(faultTolerance),
41           mFound(false)
42     {}
43 
ConstantFinder(const T & value)44     ConstantFinder(const T &value)
45         : TIntermTraverser(true, false, false), mFaultTolerance(T()), mFound(false)
46     {
47         mConstantVector.push_back(value);
48     }
49 
visitConstantUnion(TIntermConstantUnion * node)50     void visitConstantUnion(TIntermConstantUnion *node)
51     {
52         if (node->getType().getObjectSize() == mConstantVector.size())
53         {
54             bool found = true;
55             for (size_t i = 0; i < mConstantVector.size(); i++)
56             {
57                 if (!isEqual(node->getConstantValue()[i], mConstantVector[i]))
58                 {
59                     found = false;
60                     break;
61                 }
62             }
63             if (found)
64             {
65                 mFound = found;
66             }
67         }
68     }
69 
found()70     bool found() const { return mFound; }
71 
72   private:
isEqual(const TConstantUnion & node,const float & value)73     bool isEqual(const TConstantUnion &node, const float &value) const
74     {
75         if (node.getType() != EbtFloat)
76         {
77             return false;
78         }
79         if (value == std::numeric_limits<float>::infinity())
80         {
81             return gl::isInf(node.getFConst()) && node.getFConst() > 0;
82         }
83         else if (value == -std::numeric_limits<float>::infinity())
84         {
85             return gl::isInf(node.getFConst()) && node.getFConst() < 0;
86         }
87         else if (gl::isNaN(value))
88         {
89             // All NaNs are treated as equal.
90             return gl::isNaN(node.getFConst());
91         }
92         return mFaultTolerance >= fabsf(node.getFConst() - value);
93     }
94 
isEqual(const TConstantUnion & node,const int & value)95     bool isEqual(const TConstantUnion &node, const int &value) const
96     {
97         if (node.getType() != EbtInt)
98         {
99             return false;
100         }
101         ASSERT(mFaultTolerance < std::numeric_limits<int>::max());
102         // abs() returns 0 at least on some platforms when the minimum int value is passed in (it
103         // doesn't have a positive counterpart).
104         return mFaultTolerance >= abs(node.getIConst() - value) &&
105                (node.getIConst() - value) != std::numeric_limits<int>::min();
106     }
107 
isEqual(const TConstantUnion & node,const unsigned int & value)108     bool isEqual(const TConstantUnion &node, const unsigned int &value) const
109     {
110         if (node.getType() != EbtUInt)
111         {
112             return false;
113         }
114         ASSERT(mFaultTolerance < static_cast<unsigned int>(std::numeric_limits<int>::max()));
115         return static_cast<int>(mFaultTolerance) >=
116                    abs(static_cast<int>(node.getUConst() - value)) &&
117                static_cast<int>(node.getUConst() - value) != std::numeric_limits<int>::min();
118     }
119 
isEqual(const TConstantUnion & node,const bool & value)120     bool isEqual(const TConstantUnion &node, const bool &value) const
121     {
122         if (node.getType() != EbtBool)
123         {
124             return false;
125         }
126         return node.getBConst() == value;
127     }
128 
129     std::vector<T> mConstantVector;
130     T mFaultTolerance;
131     bool mFound;
132 };
133 
134 class ConstantFoldingTest : public ShaderCompileTreeTest
135 {
136   public:
ConstantFoldingTest()137     ConstantFoldingTest() {}
138 
139   protected:
getShaderType()140     ::GLenum getShaderType() const override { return GL_FRAGMENT_SHADER; }
getShaderSpec()141     ShShaderSpec getShaderSpec() const override { return SH_GLES3_1_SPEC; }
142 
143     template <typename T>
constantFoundInAST(T constant)144     bool constantFoundInAST(T constant)
145     {
146         ConstantFinder<T> finder(constant);
147         mASTRoot->traverse(&finder);
148         return finder.found();
149     }
150 
151     template <typename T>
constantVectorFoundInAST(const std::vector<T> & constantVector)152     bool constantVectorFoundInAST(const std::vector<T> &constantVector)
153     {
154         ConstantFinder<T> finder(constantVector);
155         mASTRoot->traverse(&finder);
156         return finder.found();
157     }
158 
159     template <typename T>
constantColumnMajorMatrixFoundInAST(const std::vector<T> & constantMatrix)160     bool constantColumnMajorMatrixFoundInAST(const std::vector<T> &constantMatrix)
161     {
162         return constantVectorFoundInAST(constantMatrix);
163     }
164 
165     template <typename T>
constantVectorNearFoundInAST(const std::vector<T> & constantVector,const T & faultTolerance)166     bool constantVectorNearFoundInAST(const std::vector<T> &constantVector, const T &faultTolerance)
167     {
168         ConstantFinder<T> finder(constantVector, faultTolerance);
169         mASTRoot->traverse(&finder);
170         return finder.found();
171     }
172 
symbolFoundInAST(const char * symbolName)173     bool symbolFoundInAST(const char *symbolName)
174     {
175         return FindSymbolNode(mASTRoot, ImmutableString(symbolName)) != nullptr;
176     }
177 
symbolFoundInMain(const char * symbolName)178     bool symbolFoundInMain(const char *symbolName)
179     {
180         return FindSymbolNode(FindMain(mASTRoot), ImmutableString(symbolName)) != nullptr;
181     }
182 };
183 
184 class ConstantFoldingExpressionTest : public ConstantFoldingTest
185 {
186   public:
ConstantFoldingExpressionTest()187     ConstantFoldingExpressionTest() {}
188 
189     void evaluateFloat(const std::string &floatExpression);
190     void evaluateInt(const std::string &intExpression);
191     void evaluateUint(const std::string &uintExpression);
192 };
193 
194 }  // namespace sh
195 
196 #endif  // TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_
197