• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The purpose of this file is to provide negative deserialization tests.
10 // For positive deserialization tests, please use serialization and
11 // deserialization for roundtripping.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/SPIRVModule.h"
18 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/Serialization.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "gmock/gmock.h"
23 
24 #include <memory>
25 
26 using namespace mlir;
27 
28 using ::testing::StrEq;
29 
30 //===----------------------------------------------------------------------===//
31 // Test Fixture
32 //===----------------------------------------------------------------------===//
33 
34 /// A deserialization test fixture providing minimal SPIR-V building and
35 /// diagnostic checking utilities.
36 class DeserializationTest : public ::testing::Test {
37 protected:
DeserializationTest()38   DeserializationTest() {
39     context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
40     // Register a diagnostic handler to capture the diagnostic so that we can
41     // check it later.
42     context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
43       diagnostic.reset(new Diagnostic(std::move(diag)));
44     });
45   }
46 
47   /// Performs deserialization and returns the constructed spv.module op.
deserialize()48   spirv::OwningSPIRVModuleRef deserialize() {
49     return spirv::deserialize(binary, &context);
50   }
51 
52   /// Checks there is a diagnostic generated with the given `errorMessage`.
expectDiagnostic(StringRef errorMessage)53   void expectDiagnostic(StringRef errorMessage) {
54     ASSERT_NE(nullptr, diagnostic.get());
55 
56     // TODO: check error location too.
57     EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage)));
58   }
59 
60   //===--------------------------------------------------------------------===//
61   // SPIR-V builder methods
62   //===--------------------------------------------------------------------===//
63 
64   /// Adds the SPIR-V module header to `binary`.
addHeader()65   void addHeader() {
66     spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0);
67   }
68 
69   /// Adds the SPIR-V instruction into `binary`.
addInstruction(spirv::Opcode op,ArrayRef<uint32_t> operands)70   void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
71     uint32_t wordCount = 1 + operands.size();
72     binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
73     binary.append(operands.begin(), operands.end());
74   }
75 
addVoidType()76   uint32_t addVoidType() {
77     auto id = nextID++;
78     addInstruction(spirv::Opcode::OpTypeVoid, {id});
79     return id;
80   }
81 
addIntType(uint32_t bitwidth)82   uint32_t addIntType(uint32_t bitwidth) {
83     auto id = nextID++;
84     addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
85     return id;
86   }
87 
addStructType(ArrayRef<uint32_t> memberTypes)88   uint32_t addStructType(ArrayRef<uint32_t> memberTypes) {
89     auto id = nextID++;
90     SmallVector<uint32_t, 2> words;
91     words.push_back(id);
92     words.append(memberTypes.begin(), memberTypes.end());
93     addInstruction(spirv::Opcode::OpTypeStruct, words);
94     return id;
95   }
96 
addFunctionType(uint32_t retType,ArrayRef<uint32_t> paramTypes)97   uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
98     auto id = nextID++;
99     SmallVector<uint32_t, 4> operands;
100     operands.push_back(id);
101     operands.push_back(retType);
102     operands.append(paramTypes.begin(), paramTypes.end());
103     addInstruction(spirv::Opcode::OpTypeFunction, operands);
104     return id;
105   }
106 
addFunction(uint32_t retType,uint32_t fnType)107   uint32_t addFunction(uint32_t retType, uint32_t fnType) {
108     auto id = nextID++;
109     addInstruction(spirv::Opcode::OpFunction,
110                    {retType, id,
111                     static_cast<uint32_t>(spirv::FunctionControl::None),
112                     fnType});
113     return id;
114   }
115 
addFunctionEnd()116   void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }
117 
addReturn()118   void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }
119 
120 protected:
121   SmallVector<uint32_t, 5> binary;
122   uint32_t nextID = 1;
123   MLIRContext context;
124   std::unique_ptr<Diagnostic> diagnostic;
125 };
126 
127 //===----------------------------------------------------------------------===//
128 // Basics
129 //===----------------------------------------------------------------------===//
130 
TEST_F(DeserializationTest, EmptyModuleFailure) {
  // Nothing is appended: the binary is completely empty, so it cannot even
  // hold the module header.
  auto module = deserialize();
  ASSERT_FALSE(module);
  expectDiagnostic("SPIR-V binary module must have a 5-word header");
}
135 
TEST_F(DeserializationTest, WrongMagicNumberFailure) {
  addHeader();
  // The magic number is the very first header word; clobber it with a bogus
  // value.
  const uint32_t bogusMagic = 0xdeadbeef;
  binary[0] = bogusMagic;

  ASSERT_FALSE(deserialize());
  expectDiagnostic("incorrect magic number");
}
142 
TEST_F(DeserializationTest, OnlyHeaderSuccess) {
  // A module consisting of nothing but a valid header must deserialize.
  addHeader();
  auto module = deserialize();
  EXPECT_TRUE(module);
}
147 
TEST_F(DeserializationTest, ZeroWordCountFailure) {
  addHeader();
  // An all-zero instruction word: opcode OpNop with a word count of zero,
  // which can never be a valid instruction length.
  const uint32_t zeroWordCountNop = 0u;
  binary.push_back(zeroWordCountNop);

  ASSERT_FALSE(deserialize());
  expectDiagnostic("word count cannot be zero");
}
155 
TEST_F(DeserializationTest, InsufficientWordFailure) {
  addHeader();
  // Encode an OpTypeVoid that claims to be two words long (opcode word plus
  // result <id>). Use the same packing helper as addInstruction() instead of
  // hand-rolling the shift/or, but emit only the leading word.
  binary.push_back(
      spirv::getPrefixedOpcode(/*wordCount=*/2, spirv::Opcode::OpTypeVoid));
  // Missing word for type <id>.

  ASSERT_FALSE(deserialize());
  expectDiagnostic("insufficient words for the last instruction");
}
165 
166 //===----------------------------------------------------------------------===//
167 // Types
168 //===----------------------------------------------------------------------===//
169 
TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
  addHeader();
  // Emit OpTypeInt with a result <id> and bitwidth but without the trailing
  // signedness operand.
  uint32_t intTypeId = nextID++;
  addInstruction(spirv::Opcode::OpTypeInt, {intTypeId, /*bitwidth=*/32});

  ASSERT_FALSE(deserialize());
  expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
}
177 
178 //===----------------------------------------------------------------------===//
179 // StructType
180 //===----------------------------------------------------------------------===//
181 
TEST_F(DeserializationTest, OpMemberNameSuccess) {
  addHeader();
  const size_t headerWords = binary.size();

  // Build the type declarations right after the header, then move them out
  // so that the OpMemberName instructions can be emitted before them.
  auto int32Type = addIntType(32);
  auto structType = addStructType({int32Type, int32Type});
  SmallVector<uint32_t, 5> typeDecl;
  typeDecl.append(binary.begin() + headerWords, binary.end());
  binary.resize(headerWords);

  // Name each struct member: member 0 is "i1", member 1 is "i2".
  const char *memberNames[] = {"i1", "i2"};
  for (uint32_t member = 0; member < 2; ++member) {
    SmallVector<uint32_t, 5> operands = {structType, member};
    spirv::encodeStringLiteralInto(operands, memberNames[member]);
    addInstruction(spirv::Opcode::OpMemberName, operands);
  }

  // Finally place the type declarations after the names.
  binary.append(typeDecl.begin(), typeDecl.end());
  EXPECT_TRUE(deserialize());
}
202 
TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
  addHeader();
  const size_t headerWords = binary.size();

  // Declare the struct type, then set its words aside so the OpMemberName
  // can precede them in the final binary.
  auto int32Type = addIntType(32);
  auto int64Type = addIntType(64);
  auto structType = addStructType({int32Type, int64Type});
  SmallVector<uint32_t, 5> typeDecl;
  typeDecl.append(binary.begin() + headerWords, binary.end());
  binary.resize(headerWords);

  // OpMemberName carrying only the struct type <id>: both the member index
  // and the name literal are missing.
  addInstruction(spirv::Opcode::OpMemberName, {structType});

  binary.append(typeDecl.begin(), typeDecl.end());
  ASSERT_FALSE(deserialize());
  expectDiagnostic("OpMemberName must have at least 3 operands");
}
220 
TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
  addHeader();
  const size_t headerWords = binary.size();

  // Declare a single-member struct, then set its words aside so the
  // OpMemberName can precede them in the final binary.
  auto int32Type = addIntType(32);
  auto structType = addStructType({int32Type});
  SmallVector<uint32_t, 5> typeDecl;
  typeDecl.append(binary.begin() + headerWords, binary.end());
  binary.resize(headerWords);

  // A well-formed OpMemberName for member 0 followed by one stray word.
  SmallVector<uint32_t, 5> memberNameOperands = {structType, 0};
  spirv::encodeStringLiteralInto(memberNameOperands, "int32");
  const uint32_t strayWord = 42;
  memberNameOperands.push_back(strayWord);
  addInstruction(spirv::Opcode::OpMemberName, memberNameOperands);

  binary.append(typeDecl.begin(), typeDecl.end());
  ASSERT_FALSE(deserialize());
  expectDiagnostic("unexpected trailing words in OpMemberName instruction");
}
239 
240 //===----------------------------------------------------------------------===//
241 // Functions
242 //===----------------------------------------------------------------------===//
243 
TEST_F(DeserializationTest, FunctionMissingEndFailure) {
  addHeader();
  // Declare `void()` and open a function of that type...
  uint32_t voidTy = addVoidType();
  uint32_t voidFnTy = addFunctionType(voidTy, /*paramTypes=*/{});
  addFunction(voidTy, voidFnTy);
  // ...but never close it: OpFunctionEnd is deliberately omitted.

  ASSERT_FALSE(deserialize());
  expectDiagnostic("expected OpFunctionEnd instruction");
}
254 
TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
  addHeader();
  // The function type declares one i32 parameter...
  uint32_t voidTy = addVoidType();
  uint32_t int32Ty = addIntType(32);
  uint32_t unaryFnTy = addFunctionType(voidTy, {int32Ty});
  addFunction(voidTy, unaryFnTy);
  // ...yet no OpFunctionParameter follows the OpFunction.

  ASSERT_FALSE(deserialize());
  expectDiagnostic("expected OpFunctionParameter instruction");
}
266 
TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
  addHeader();
  uint32_t voidTy = addVoidType();
  uint32_t voidFnTy = addFunctionType(voidTy, /*paramTypes=*/{});
  addFunction(voidTy, voidFnTy);
  // The body jumps straight to OpReturn without opening a block via OpLabel.
  addReturn();
  addFunctionEnd();

  ASSERT_FALSE(deserialize());
  expectDiagnostic("a basic block must start with OpLabel");
}
279 
TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
  addHeader();
  uint32_t voidTy = addVoidType();
  uint32_t voidFnTy = addFunctionType(voidTy, /*paramTypes=*/{});
  addFunction(voidTy, voidFnTy);
  // An OpLabel that lacks its result <id> operand.
  addInstruction(spirv::Opcode::OpLabel, ArrayRef<uint32_t>());
  addReturn();
  addFunctionEnd();

  ASSERT_FALSE(deserialize());
  expectDiagnostic("OpLabel should only have result <id>");
}
292