1 /*
2  * Copyright 2017, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pass_queue.h"
18 
19 #include "file_utils.h"
20 #include "spirit.h"
21 #include "test_utils.h"
22 #include "transformer.h"
23 #include "gtest/gtest.h"
24 
25 #include <stdint.h>
26 
27 namespace android {
28 namespace spirit {
29 
30 namespace {
31 
32 class MulToAddTransformer : public Transformer {
33 public:
transform(IMulInst * mul)34   Instruction *transform(IMulInst *mul) override {
35     auto ret = new IAddInst(mul->mResultType, mul->mOperand1, mul->mOperand2);
36     ret->setId(mul->getId());
37     return ret;
38   }
39 };
40 
41 class AddToDivTransformer : public Transformer {
42 public:
transform(IAddInst * add)43   Instruction *transform(IAddInst *add) override {
44     auto ret = new SDivInst(add->mResultType, add->mOperand1, add->mOperand2);
45     ret->setId(add->getId());
46     return ret;
47   }
48 };
49 
50 class AddMulAfterAddTransformer : public Transformer {
51 public:
transform(IAddInst * add)52   Instruction *transform(IAddInst *add) override {
53     insert(add);
54     auto ret = new IMulInst(add->mResultType, add, add);
55     ret->setId(add->getId());
56     return ret;
57   }
58 };
59 
60 class Deleter : public Transformer {
61 public:
transform(IMulInst *)62   Instruction *transform(IMulInst *) override { return nullptr; }
63 };
64 
65 class InPlaceModifyingPass : public Pass {
66 public:
run(Module * m,int * error)67   Module *run(Module *m, int *error) override {
68     m->getFloatType(64);
69     if (error) {
70       *error = 0;
71     }
72     return m;
73   }
74 };
75 
} // anonymous namespace
77 
78 class PassQueueTest : public ::testing::Test {
79 protected:
SetUp()80   virtual void SetUp() { mWordsGreyscale = readWords("greyscale.spv"); }
81 
82   std::vector<uint32_t> mWordsGreyscale;
83 
84 private:
readWords(const char * testFile)85   std::vector<uint32_t> readWords(const char *testFile) {
86     static const std::string testDataPath(
87         "frameworks/rs/rsov/compiler/spirit/test_data/");
88     const std::string &fullPath = getAbsolutePath(testDataPath + testFile);
89     return readFile<uint32_t>(fullPath);
90   }
91 };
92 
// Running MulToAddTransformer rewrites the single IMulInst of greyscale.spv
// into an IAddInst, leaving two adds and no multiplies.
TEST_F(PassQueueTest, testMulToAdd) {
  std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));

  ASSERT_NE(nullptr, m);

  EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
  EXPECT_EQ(1, countEntity<IMulInst>(m.get()));

  PassQueue passes;
  passes.append(new MulToAddTransformer());
  // A transformer queue hands back a separate module that the caller owns
  // (testAMix takes the same ownership). Previously this pointer was never
  // freed.
  std::unique_ptr<Module> m1(passes.run(m.get()));

  ASSERT_NE(nullptr, m1);

  ASSERT_TRUE(m1->resolveIds());

  EXPECT_EQ(2, countEntity<IAddInst>(m1.get()));
  EXPECT_EQ(0, countEntity<IMulInst>(m1.get()));
}
112 
// An in-place pass adds a 64-bit float type without touching arithmetic:
// the TypeFloatInst count goes from one to two, while the single IAddInst
// and IMulInst are preserved.
TEST_F(PassQueueTest, testInPlaceModifying) {
  std::unique_ptr<Module> input(Deserialize<Module>(mWordsGreyscale));
  ASSERT_NE(nullptr, input);

  // Baseline counts for greyscale.spv.
  EXPECT_EQ(1, countEntity<IAddInst>(input.get()));
  EXPECT_EQ(1, countEntity<IMulInst>(input.get()));
  EXPECT_EQ(1, countEntity<TypeFloatInst>(input.get()));

  PassQueue queue;
  queue.append(new InPlaceModifyingPass());

  // Deliberately not owned: the pass returns the module it was given, which
  // `input` already owns (presumably the queue passes it straight through).
  Module *output = queue.run(input.get());
  ASSERT_NE(nullptr, output);
  ASSERT_TRUE(output->resolveIds());

  EXPECT_EQ(1, countEntity<IAddInst>(output));
  EXPECT_EQ(1, countEntity<IMulInst>(output));
  EXPECT_EQ(2, countEntity<TypeFloatInst>(output));
}
134 
// Deleting the only IMulInst leaves a use of its id unresolved, so the
// queue is expected to reject the result and return nullptr.
TEST_F(PassQueueTest, testDeletion) {
  std::unique_ptr<Module> input(Deserialize<Module>(mWordsGreyscale));
  ASSERT_NE(nullptr, input.get());

  EXPECT_EQ(1, countEntity<IMulInst>(input.get()));

  PassQueue queue;
  queue.append(new Deleter());
  auto output = queue.run(input.get());

  // One of the ids from the input module is missing now.
  ASSERT_EQ(nullptr, output);
}
149 
// Chaining MulToAddTransformer and AddToDivTransformer turns both the
// original multiply and the original add into SDivInsts.
TEST_F(PassQueueTest, testMulToAddToDiv) {
  std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));

  ASSERT_NE(nullptr, m);

  EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
  EXPECT_EQ(1, countEntity<IMulInst>(m.get()));

  PassQueue passes;
  passes.append(new MulToAddTransformer());
  passes.append(new AddToDivTransformer());
  // The transformer queue produces a separate module owned by the caller
  // (as testAMix also assumes); the raw pointer used previously leaked it.
  std::unique_ptr<Module> m1(passes.run(m.get()));

  ASSERT_NE(nullptr, m1);

  ASSERT_TRUE(m1->resolveIds());

  EXPECT_EQ(0, countEntity<IAddInst>(m1.get()));
  EXPECT_EQ(0, countEntity<IMulInst>(m1.get()));
  EXPECT_EQ(2, countEntity<SDivInst>(m1.get()));
}
171 
// Two transformers (mul -> add, add -> div) followed by an in-place pass.
// All arithmetic ends up as SDivInst, and a second float type appears.
TEST_F(PassQueueTest, testAMix) {
  std::unique_ptr<Module> input(Deserialize<Module>(mWordsGreyscale));
  ASSERT_NE(nullptr, input);

  // Baseline counts for greyscale.spv.
  EXPECT_EQ(1, countEntity<IAddInst>(input.get()));
  EXPECT_EQ(1, countEntity<IMulInst>(input.get()));
  EXPECT_EQ(0, countEntity<SDivInst>(input.get()));
  EXPECT_EQ(1, countEntity<TypeFloatInst>(input.get()));

  PassQueue queue;
  queue.append(new MulToAddTransformer());
  queue.append(new AddToDivTransformer());
  queue.append(new InPlaceModifyingPass());

  // The result is owned separately from `input`, so the queue must hand
  // back a newly created module here.
  std::unique_ptr<Module> output(queue.run(input.get()));
  ASSERT_NE(nullptr, output);
  ASSERT_TRUE(output->resolveIds());

  Module *const result = output.get();
  EXPECT_EQ(0, countEntity<IAddInst>(result));
  EXPECT_EQ(0, countEntity<IMulInst>(result));
  EXPECT_EQ(2, countEntity<SDivInst>(result));
  EXPECT_EQ(2, countEntity<TypeFloatInst>(result));
}
198 
// In-place pass first, then the two transformers. The output is taken as
// serialized words and re-parsed, then checked for SDivInsts only and two
// float types.
TEST_F(PassQueueTest, testAnotherMix) {
  std::unique_ptr<Module> input(Deserialize<Module>(mWordsGreyscale));
  ASSERT_NE(nullptr, input);

  // Baseline counts for greyscale.spv.
  EXPECT_EQ(1, countEntity<IAddInst>(input.get()));
  EXPECT_EQ(1, countEntity<IMulInst>(input.get()));
  EXPECT_EQ(0, countEntity<SDivInst>(input.get()));
  EXPECT_EQ(1, countEntity<TypeFloatInst>(input.get()));

  PassQueue queue;
  queue.append(new InPlaceModifyingPass());
  queue.append(new MulToAddTransformer());
  queue.append(new AddToDivTransformer());
  auto words = queue.runAndSerialize(input.get());

  std::unique_ptr<Module> output(Deserialize<Module>(words));
  ASSERT_NE(nullptr, output);
  ASSERT_TRUE(output->resolveIds());

  Module *const result = output.get();
  EXPECT_EQ(0, countEntity<IAddInst>(result));
  EXPECT_EQ(0, countEntity<IMulInst>(result));
  EXPECT_EQ(2, countEntity<SDivInst>(result));
  EXPECT_EQ(2, countEntity<TypeFloatInst>(result));
}
226 
// Runs the mul -> add -> div chain directly on the fixture's raw words
// (moved into the queue) and checks the re-parsed output.
TEST_F(PassQueueTest, testMulToAddToDivFromWords) {
  PassQueue queue;
  queue.append(new MulToAddTransformer());
  queue.append(new AddToDivTransformer());

  auto words = queue.run(std::move(mWordsGreyscale));

  std::unique_ptr<Module> output(Deserialize<Module>(words));
  ASSERT_NE(nullptr, output);
  ASSERT_TRUE(output->resolveIds());

  Module *const result = output.get();
  EXPECT_EQ(0, countEntity<IAddInst>(result));
  EXPECT_EQ(0, countEntity<IMulInst>(result));
  EXPECT_EQ(2, countEntity<SDivInst>(result));
}
243 
// Runs the mul -> add -> div chain on a parsed module, serializes the
// result, and checks the re-parsed output.
TEST_F(PassQueueTest, testMulToAddToDivToWords) {
  std::unique_ptr<Module> input(Deserialize<Module>(mWordsGreyscale));
  ASSERT_NE(nullptr, input);

  // Baseline counts for greyscale.spv.
  EXPECT_EQ(1, countEntity<IAddInst>(input.get()));
  EXPECT_EQ(1, countEntity<IMulInst>(input.get()));

  PassQueue queue;
  queue.append(new MulToAddTransformer());
  queue.append(new AddToDivTransformer());
  auto words = queue.runAndSerialize(input.get());

  std::unique_ptr<Module> output(Deserialize<Module>(words));
  ASSERT_NE(nullptr, output);
  ASSERT_TRUE(output->resolveIds());

  Module *const result = output.get();
  EXPECT_EQ(0, countEntity<IAddInst>(result));
  EXPECT_EQ(0, countEntity<IMulInst>(result));
  EXPECT_EQ(2, countEntity<SDivInst>(result));
}
267 
// Each AddMulAfterAddTransformer pass keeps the add and contributes one more
// IMulInst. After kNumPasses passes the module holds the original multiply
// plus kNumPasses new ones, and still exactly one add.
TEST_F(PassQueueTest, testAddMulAfterAdd) {
  std::unique_ptr<Module> input(Deserialize<Module>(mWordsGreyscale));
  ASSERT_NE(nullptr, input);

  // Baseline counts for greyscale.spv.
  EXPECT_EQ(1, countEntity<IAddInst>(input.get()));
  EXPECT_EQ(1, countEntity<IMulInst>(input.get()));

  constexpr int kNumPasses = 100;

  PassQueue queue;
  for (int remaining = kNumPasses; remaining > 0; --remaining) {
    queue.append(new AddMulAfterAddTransformer());
  }
  auto words = queue.runAndSerialize(input.get());

  std::unique_ptr<Module> output(Deserialize<Module>(words));
  ASSERT_NE(nullptr, output);
  ASSERT_TRUE(output->resolveIds());

  EXPECT_EQ(1, countEntity<IAddInst>(output.get()));
  EXPECT_EQ(1 + kNumPasses, countEntity<IMulInst>(output.get()));
}
293 
294 } // namespace spirit
295 } // namespace android
296