1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include <cstdint>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "spirv-tools/optimizer.hpp"
#include "spirv/1.1/spirv.h"
23 
24 namespace spvtools {
25 namespace {
26 
27 using ::testing::ContainerEq;
28 using ::testing::HasSubstr;
29 
// Builds the smallest preamble (capabilities plus memory model) that forms a
// valid module.  Callers append further instructions to the returned text.
std::string Header() {
  std::string preamble = "OpCapability Shader\n";
  preamble += "OpCapability Linkage\n";
  preamble += "OpMemoryModel Logical GLSL450\n";
  return preamble;
}
38 
// Version word expected in the module header when assembling for the SPIR-V
// 1.1 target environment.  The SPIR-V header stores the version as the bytes
// 0 | major | minor | 0, so major 1 and minor 1 give 0x00010100.
const uint32_t kExpectedSpvVersion = (1u << 16) | (1u << 8);
42 
// Assembles a one-instruction module and checks the produced header.  It then
// confirms the validator rejects the module with the expected diagnostic and
// that disassembly reproduces the original text.
TEST(CppInterface, SuccessfulRoundTrip) {
  const std::string input_text = "%2 = OpSizeOf %1 %3\n";
  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);

  std::vector<uint32_t> binary;
  // Fatal checks: the header words are indexed below, so an empty or
  // truncated binary must stop the test instead of being read out of bounds.
  ASSERT_TRUE(t.Assemble(input_text, &binary));
  ASSERT_GT(binary.size(), 5u);
  EXPECT_EQ(SpvMagicNumber, binary[0]);
  EXPECT_EQ(kExpectedSpvVersion, binary[1]);

  // This cannot pass validation since %1 is not defined.
  int invocation_count = 0;
  t.SetMessageConsumer([&invocation_count](spv_message_level_t level,
                                           const char* source,
                                           const spv_position_t& position,
                                           const char* message) {
    ++invocation_count;
    EXPECT_EQ(SPV_MSG_ERROR, level);
    EXPECT_STREQ("input", source);
    EXPECT_EQ(0u, position.line);
    EXPECT_EQ(0u, position.column);
    EXPECT_EQ(1u, position.index);
    EXPECT_STREQ("ID 1[%1] has not been defined\n  %2 = OpSizeOf %1 %3\n",
                 message);
  });
  EXPECT_FALSE(t.Validate(binary));
  // Without this count, the checks inside the consumer would pass vacuously
  // if no diagnostic were ever delivered.
  EXPECT_EQ(1, invocation_count);

  std::string output_text;
  EXPECT_TRUE(t.Disassemble(binary, &output_text));
  EXPECT_EQ(input_text, output_text);
}
70 
// Assembling empty text must replace the previous contents of the output
// vector with nothing but the five-word module header.
TEST(CppInterface, AssembleEmptyModule) {
  // Pre-filled with junk so stale words left behind would be detected.
  std::vector<uint32_t> binary(10, 42);
  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
  ASSERT_TRUE(t.Assemble("", &binary));
  // We only have the header.  The size check is fatal so the header words are
  // inspected only when the output has the expected shape.
  ASSERT_EQ(5u, binary.size());
  EXPECT_EQ(SpvMagicNumber, binary[0]);
  EXPECT_EQ(kExpectedSpvVersion, binary[1]);
}
80 
// Exercises the std::string and (pointer, size) Assemble() overloads.  It also
// uses a (pointer, size) range that omits the trailing newline.
TEST(CppInterface, AssembleOverloads) {
  const std::string input_text = "%2 = OpSizeOf %1 %3\n";
  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);

  // Checks that |binary| is longer than a bare five-word header and that the
  // header carries the SPIR-V magic number and the expected version.  The size
  // check is fatal, but it returns only from this lambda.  As a result, a short
  // or empty binary is never indexed out of bounds.
  const auto expect_module_header =
      [](const std::vector<uint32_t>& binary) -> void {
    ASSERT_GT(binary.size(), 5u);
    EXPECT_EQ(SpvMagicNumber, binary[0]);
    EXPECT_EQ(kExpectedSpvVersion, binary[1]);
  };

  {
    std::vector<uint32_t> binary;
    EXPECT_TRUE(t.Assemble(input_text, &binary));
    expect_module_header(binary);
  }
  {
    std::vector<uint32_t> binary;
    EXPECT_TRUE(t.Assemble(input_text.data(), input_text.size(), &binary));
    expect_module_header(binary);
  }
  {  // Ignore the last newline.
    std::vector<uint32_t> binary;
    EXPECT_TRUE(t.Assemble(input_text.data(), input_text.size() - 1, &binary));
    expect_module_header(binary);
  }
}
106 
// Disassembling an empty binary must fail and report exactly one
// "Missing module." error to the consumer.  The output string must be left
// untouched.
TEST(CppInterface, DisassembleEmptyModule) {
  std::string text(10, 'x');
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  int calls = 0;

  const auto consumer = [&calls](spv_message_level_t level, const char* source,
                                 const spv_position_t& position,
                                 const char* message) {
    ++calls;
    EXPECT_EQ(SPV_MSG_ERROR, level);
    EXPECT_STREQ("input", source);
    EXPECT_EQ(0u, position.line);
    EXPECT_EQ(0u, position.column);
    EXPECT_EQ(0u, position.index);
    EXPECT_STREQ("Missing module.", message);
  };
  tools.SetMessageConsumer(consumer);

  const std::vector<uint32_t> empty_binary;
  EXPECT_FALSE(tools.Disassemble(empty_binary, &text));
  EXPECT_EQ("xxxxxxxxxx", text);  // The original string is unmodified.
  EXPECT_EQ(1, calls);
}
126 
// Both Disassemble() overloads, vector and (pointer, size), must reproduce the
// text that was assembled.
TEST(CppInterface, DisassembleOverloads) {
  const std::string source_text = "%2 = OpSizeOf %1 %3\n";
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);

  std::vector<uint32_t> words;
  EXPECT_TRUE(tools.Assemble(source_text, &words));

  std::string from_vector;
  EXPECT_TRUE(tools.Disassemble(words, &from_vector));
  EXPECT_EQ(source_text, from_vector);

  std::string from_pointer;
  EXPECT_TRUE(tools.Disassemble(words.data(), words.size(), &from_pointer));
  EXPECT_EQ(source_text, from_pointer);
}
145 
// A minimal, well-formed module must validate without delivering any
// diagnostics to the message consumer.
TEST(CppInterface, SuccessfulValidation) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  int messages = 0;
  tools.SetMessageConsumer(
      [&messages](spv_message_level_t, const char*, const spv_position_t&,
                  const char*) { ++messages; });

  std::vector<uint32_t> words;
  EXPECT_TRUE(tools.Assemble(Header(), &words));
  EXPECT_TRUE(tools.Validate(words));
  EXPECT_EQ(0, messages);
}
159 
// Both Validate() overloads, vector and (pointer, size), must accept a minimal
// valid module.
TEST(CppInterface, ValidateOverloads) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  std::vector<uint32_t> words;
  EXPECT_TRUE(tools.Assemble(Header(), &words));

  EXPECT_TRUE(tools.Validate(words));
  EXPECT_TRUE(tools.Validate(words.data(), words.size()));
}
168 
// Validating an empty binary must fail and report exactly one
// "Invalid SPIR-V magic number." error to the consumer.
TEST(CppInterface, ValidateEmptyModule) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  int calls = 0;

  const auto consumer = [&calls](spv_message_level_t level, const char* source,
                                 const spv_position_t& position,
                                 const char* message) {
    ++calls;
    EXPECT_EQ(SPV_MSG_ERROR, level);
    EXPECT_STREQ("input", source);
    EXPECT_EQ(0u, position.line);
    EXPECT_EQ(0u, position.column);
    EXPECT_EQ(0u, position.index);
    EXPECT_STREQ("Invalid SPIR-V magic number.", message);
  };
  tools.SetMessageConsumer(consumer);

  const std::vector<uint32_t> empty_binary;
  EXPECT_FALSE(tools.Validate(empty_binary));
  EXPECT_EQ(1, calls);
}
186 
187 // Returns the assembly for a SPIR-V module with a struct declaration
188 // with the given number of members.
MakeModuleHavingStruct(int num_members)189 std::string MakeModuleHavingStruct(int num_members) {
190   std::stringstream os;
191   os << Header();
192   os << R"(%1 = OpTypeInt 32 0
193            %2 = OpTypeStruct)";
194   for (int i = 0; i < num_members; i++) os << " %1";
195   return os.str();
196 }
197 
// With default validator options, a ten-member struct is within limits.
TEST(CppInterface, ValidateWithOptionsPass) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  std::vector<uint32_t> words;
  EXPECT_TRUE(tools.Assemble(MakeModuleHavingStruct(10), &words));

  const ValidatorOptions default_options;
  EXPECT_TRUE(tools.Validate(words.data(), words.size(), default_options));
}
206 
// Lowering the struct-member limit to 9 must make a ten-member struct fail
// validation.  The diagnostic must name both the member count and the limit.
TEST(CppInterface, ValidateWithOptionsFail) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  std::vector<uint32_t> words;
  EXPECT_TRUE(tools.Assemble(MakeModuleHavingStruct(10), &words));

  ValidatorOptions options;
  options.SetUniversalLimit(spv_validator_limit_max_struct_members, 9);

  // Concatenate every diagnostic so the expected text can be searched for.
  std::string diagnostics;
  tools.SetMessageConsumer(
      [&diagnostics](spv_message_level_t, const char*, const spv_position_t&,
                     const char* message) { diagnostics += message; });

  EXPECT_FALSE(tools.Validate(words.data(), words.size(), options));
  EXPECT_THAT(
      diagnostics,
      HasSubstr(
          "Number of OpTypeStruct members (10) has exceeded the limit (9)"));
}
224 
225 // Checks that after running the given optimizer |opt| on the given |original|
226 // source code, we can get the given |optimized| source code.
CheckOptimization(const std::string & original,const std::string & optimized,const Optimizer & opt)227 void CheckOptimization(const std::string& original,
228                        const std::string& optimized, const Optimizer& opt) {
229   SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
230   std::vector<uint32_t> original_binary;
231   ASSERT_TRUE(t.Assemble(original, &original_binary));
232 
233   std::vector<uint32_t> optimized_binary;
234   EXPECT_TRUE(opt.Run(original_binary.data(), original_binary.size(),
235                       &optimized_binary));
236 
237   std::string optimized_text;
238   EXPECT_TRUE(t.Disassemble(optimized_binary, &optimized_text));
239   EXPECT_EQ(optimized, optimized_text);
240 }
241 
// Running the optimizer on a header-only module (assembled from empty text)
// must report failure.  Input and output deliberately share one vector.
TEST(CppInterface, OptimizeEmptyModule) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  std::vector<uint32_t> words;
  EXPECT_TRUE(tools.Assemble("", &words));

  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(CreateStripDebugInfoPass());

  // Fails to validate.
  EXPECT_FALSE(optimizer.Run(words.data(), words.size(), &words));
}
253 
// Stripping debug info removes the OpSource instruction, leaving only the
// preamble.
TEST(CppInterface, OptimizeModifiedModule) {
  const std::string with_debug_info = Header() + "OpSource GLSL 450";
  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(CreateStripDebugInfoPass());
  CheckOptimization(with_debug_info, Header(), optimizer);
}
259 
// Chains two passes through RegisterPass()'s returned reference.  Debug info
// (OpSource and the SpecId decoration) is stripped, and the spec constant is
// frozen into a regular constant.
TEST(CppInterface, OptimizeMulitplePasses) {
  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(CreateStripDebugInfoPass())
      .RegisterPass(CreateFreezeSpecConstantValuePass());

  const std::string before = Header() +
                             "OpSource GLSL 450 "
                             "OpDecorate %true SpecId 1 "
                             "%bool = OpTypeBool "
                             "%true = OpSpecConstantTrue %bool";
  const std::string after = Header() +
                            "%bool = OpTypeBool\n"
                            "%true = OpConstantTrue %bool\n";

  CheckOptimization(before, after, optimizer);
}
277 
// Creating pass tokens without ever registering them must be harmless.  The
// discarded temporary and the named token are both destroyed unused.
TEST(CppInterface, OptimizeDoNothingWithPassToken) {
  CreateFreezeSpecConstantValuePass();
  auto unregistered = CreateUnifyConstantPass();
  static_cast<void>(unregistered);
}
282 
// Assigning a new pass into a token that holds the null pass must replace it.
// The registered token therefore strips debug info.
TEST(CppInterface, OptimizeReassignPassToken) {
  auto token = CreateNullPass();
  token = CreateStripDebugInfoPass();

  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(std::move(token));
  CheckOptimization(Header() + "OpSource GLSL 450", Header(), optimizer);
}
291 
// A pass token move-constructed from another must carry the original pass.
TEST(CppInterface, OptimizeMoveConstructPassToken) {
  auto source_token = CreateStripDebugInfoPass();
  Optimizer::PassToken moved_token(std::move(source_token));

  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(std::move(moved_token));
  CheckOptimization(Header() + "OpSource GLSL 450", Header(), optimizer);
}
300 
// Move-assigning into a token that holds the null pass must transfer the
// source token's pass.
TEST(CppInterface, OptimizeMoveAssignPassToken) {
  auto strip_token = CreateStripDebugInfoPass();
  auto target_token = CreateNullPass();
  target_token = std::move(strip_token);

  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(std::move(target_token));
  CheckOptimization(Header() + "OpSource GLSL 450", Header(), optimizer);
}
310 
// Run() must succeed when its output vector is the same vector that holds the
// input binary.
TEST(CppInterface, OptimizeSameAddressForOriginalOptimizedBinary) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  std::vector<uint32_t> words;
  ASSERT_TRUE(tools.Assemble(Header() + "OpSource GLSL 450", &words));

  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(CreateStripDebugInfoPass());
  // Input and output deliberately alias the same vector.
  EXPECT_TRUE(optimizer.Run(words.data(), words.size(), &words));

  std::string disassembly;
  EXPECT_TRUE(tools.Disassemble(words, &disassembly));
  EXPECT_EQ(Header(), disassembly);
}
324 
325 // TODO(antiagainst): tests for SetMessageConsumer().
326 
327 }  // namespace
328 }  // namespace spvtools
329