1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The purpose of this file is to provide negative deserialization tests.
10 // For positive deserialization tests, please use serialization and
11 // deserialization for roundtripping.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/SPIRVModule.h"
18 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/Serialization.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "gmock/gmock.h"
23
24 #include <memory>
25
26 using namespace mlir;
27
28 using ::testing::StrEq;
29
30 //===----------------------------------------------------------------------===//
31 // Test Fixture
32 //===----------------------------------------------------------------------===//
33
34 /// A deserialization test fixture providing minimal SPIR-V building and
35 /// diagnostic checking utilities.
36 class DeserializationTest : public ::testing::Test {
37 protected:
DeserializationTest()38 DeserializationTest() {
39 context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
40 // Register a diagnostic handler to capture the diagnostic so that we can
41 // check it later.
42 context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
43 diagnostic.reset(new Diagnostic(std::move(diag)));
44 });
45 }
46
47 /// Performs deserialization and returns the constructed spv.module op.
deserialize()48 spirv::OwningSPIRVModuleRef deserialize() {
49 return spirv::deserialize(binary, &context);
50 }
51
52 /// Checks there is a diagnostic generated with the given `errorMessage`.
expectDiagnostic(StringRef errorMessage)53 void expectDiagnostic(StringRef errorMessage) {
54 ASSERT_NE(nullptr, diagnostic.get());
55
56 // TODO: check error location too.
57 EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage)));
58 }
59
60 //===--------------------------------------------------------------------===//
61 // SPIR-V builder methods
62 //===--------------------------------------------------------------------===//
63
64 /// Adds the SPIR-V module header to `binary`.
addHeader()65 void addHeader() {
66 spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0);
67 }
68
69 /// Adds the SPIR-V instruction into `binary`.
addInstruction(spirv::Opcode op,ArrayRef<uint32_t> operands)70 void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
71 uint32_t wordCount = 1 + operands.size();
72 binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
73 binary.append(operands.begin(), operands.end());
74 }
75
addVoidType()76 uint32_t addVoidType() {
77 auto id = nextID++;
78 addInstruction(spirv::Opcode::OpTypeVoid, {id});
79 return id;
80 }
81
addIntType(uint32_t bitwidth)82 uint32_t addIntType(uint32_t bitwidth) {
83 auto id = nextID++;
84 addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
85 return id;
86 }
87
addStructType(ArrayRef<uint32_t> memberTypes)88 uint32_t addStructType(ArrayRef<uint32_t> memberTypes) {
89 auto id = nextID++;
90 SmallVector<uint32_t, 2> words;
91 words.push_back(id);
92 words.append(memberTypes.begin(), memberTypes.end());
93 addInstruction(spirv::Opcode::OpTypeStruct, words);
94 return id;
95 }
96
addFunctionType(uint32_t retType,ArrayRef<uint32_t> paramTypes)97 uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
98 auto id = nextID++;
99 SmallVector<uint32_t, 4> operands;
100 operands.push_back(id);
101 operands.push_back(retType);
102 operands.append(paramTypes.begin(), paramTypes.end());
103 addInstruction(spirv::Opcode::OpTypeFunction, operands);
104 return id;
105 }
106
addFunction(uint32_t retType,uint32_t fnType)107 uint32_t addFunction(uint32_t retType, uint32_t fnType) {
108 auto id = nextID++;
109 addInstruction(spirv::Opcode::OpFunction,
110 {retType, id,
111 static_cast<uint32_t>(spirv::FunctionControl::None),
112 fnType});
113 return id;
114 }
115
addFunctionEnd()116 void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }
117
addReturn()118 void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }
119
120 protected:
121 SmallVector<uint32_t, 5> binary;
122 uint32_t nextID = 1;
123 MLIRContext context;
124 std::unique_ptr<Diagnostic> diagnostic;
125 };
126
127 //===----------------------------------------------------------------------===//
128 // Basics
129 //===----------------------------------------------------------------------===//
130
TEST_F(DeserializationTest, EmptyModuleFailure) {
  // Nothing is appended to the binary, so not even the header is present.
  auto spvModule = deserialize();
  ASSERT_FALSE(spvModule);
  expectDiagnostic("SPIR-V binary module must have a 5-word header");
}
135
TEST_F(DeserializationTest, WrongMagicNumberFailure) {
  addHeader();
  // Overwrite the first header word (the magic number) with a bogus value.
  uint32_t &magicWord = binary[0];
  magicWord = 0xdeadbeef;

  auto spvModule = deserialize();
  ASSERT_FALSE(spvModule);
  expectDiagnostic("incorrect magic number");
}
142
TEST_F(DeserializationTest, OnlyHeaderSuccess) {
  // A binary that consists solely of a module header is accepted.
  addHeader();
  auto spvModule = deserialize();
  EXPECT_TRUE(spvModule);
}
147
TEST_F(DeserializationTest, ZeroWordCountFailure) {
  addHeader();
  // An all-zero word decodes as OpNop whose word-count field is zero.
  const uint32_t zeroWordCountNop = 0;
  binary.push_back(zeroWordCountNop);

  auto spvModule = deserialize();
  ASSERT_FALSE(spvModule);
  expectDiagnostic("word count cannot be zero");
}
155
TEST_F(DeserializationTest, InsufficientWordFailure) {
  addHeader();
  // The OpTypeVoid opcode word announces two words, but the binary ends right
  // after it, so the type <id> word is missing.
  binary.push_back(
      spirv::getPrefixedOpcode(/*wordCount=*/2, spirv::Opcode::OpTypeVoid));

  auto spvModule = deserialize();
  ASSERT_FALSE(spvModule);
  expectDiagnostic("insufficient words for the last instruction");
}
165
166 //===----------------------------------------------------------------------===//
167 // Types
168 //===----------------------------------------------------------------------===//
169
TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
  addHeader();
  // Emit OpTypeInt with a result <id> and bitwidth but no signedness operand.
  const uint32_t intTypeID = nextID++;
  addInstruction(spirv::Opcode::OpTypeInt, {intTypeID, /*bitwidth=*/32});

  auto spvModule = deserialize();
  ASSERT_FALSE(spvModule);
  expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
}
177
178 //===----------------------------------------------------------------------===//
179 // StructType
180 //===----------------------------------------------------------------------===//
181
TEST_F(DeserializationTest, OpMemberNameSuccess) {
  addHeader();

  // Build the type declarations in a scratch buffer and append them only at
  // the end, so that the OpMemberName instructions precede them in the binary.
  SmallVector<uint32_t, 5> typeDecl;
  binary.swap(typeDecl);
  uint32_t i32Ty = addIntType(32);
  uint32_t structTy = addStructType({i32Ty, i32Ty});
  binary.swap(typeDecl);

  // Name the struct members in order: member 0 is "i1", member 1 is "i2".
  uint32_t memberIndex = 0;
  for (StringRef memberName : {"i1", "i2"}) {
    SmallVector<uint32_t, 5> operands = {structTy, memberIndex++};
    spirv::encodeStringLiteralInto(operands, memberName);
    addInstruction(spirv::Opcode::OpMemberName, operands);
  }

  binary.append(typeDecl.begin(), typeDecl.end());
  auto spvModule = deserialize();
  EXPECT_TRUE(spvModule);
}
202
TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
  addHeader();

  // Collect the type declarations separately so they land after OpMemberName.
  SmallVector<uint32_t, 5> typeDecl;
  binary.swap(typeDecl);
  uint32_t i32Ty = addIntType(32);
  uint32_t i64Ty = addIntType(64);
  uint32_t structTy = addStructType({i32Ty, i64Ty});
  binary.swap(typeDecl);

  // Only the struct type <id> is supplied; member index and name are absent.
  SmallVector<uint32_t, 5> truncatedOperands = {structTy};
  addInstruction(spirv::Opcode::OpMemberName, truncatedOperands);

  binary.append(typeDecl.begin(), typeDecl.end());
  auto spvModule = deserialize();
  ASSERT_FALSE(spvModule);
  expectDiagnostic("OpMemberName must have at least 3 operands");
}
220
TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
  addHeader();

  // Collect the type declarations separately so they land after OpMemberName.
  SmallVector<uint32_t, 5> typeDecl;
  binary.swap(typeDecl);
  uint32_t i32Ty = addIntType(32);
  uint32_t structTy = addStructType({i32Ty});
  binary.swap(typeDecl);

  // A well-formed OpMemberName for member 0, followed by one stray word.
  const uint32_t trailingWord = 42;
  SmallVector<uint32_t, 5> memberNameOperands = {structTy, 0};
  spirv::encodeStringLiteralInto(memberNameOperands, "int32");
  memberNameOperands.push_back(trailingWord);
  addInstruction(spirv::Opcode::OpMemberName, memberNameOperands);

  binary.append(typeDecl.begin(), typeDecl.end());
  auto spvModule = deserialize();
  ASSERT_FALSE(spvModule);
  expectDiagnostic("unexpected trailing words in OpMemberName instruction");
}
239
240 //===----------------------------------------------------------------------===//
241 // Functions
242 //===----------------------------------------------------------------------===//
243
TEST_F(DeserializationTest, FunctionMissingEndFailure) {
  addHeader();
  // Open a `void()` function but never close it with OpFunctionEnd.
  uint32_t voidTy = addVoidType();
  uint32_t fnTy = addFunctionType(voidTy, {});
  addFunction(voidTy, fnTy);

  auto spvModule = deserialize();
  ASSERT_FALSE(spvModule);
  expectDiagnostic("expected OpFunctionEnd instruction");
}
254
TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
  addHeader();
  // The function type takes one i32 argument, but no OpFunctionParameter
  // follows the OpFunction instruction.
  uint32_t voidTy = addVoidType();
  uint32_t i32Ty = addIntType(32);
  uint32_t fnTy = addFunctionType(voidTy, {i32Ty});
  addFunction(voidTy, fnTy);

  auto spvModule = deserialize();
  ASSERT_FALSE(spvModule);
  expectDiagnostic("expected OpFunctionParameter instruction");
}
266
TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
  addHeader();
  uint32_t voidTy = addVoidType();
  uint32_t fnTy = addFunctionType(voidTy, {});
  addFunction(voidTy, fnTy);
  // The body starts directly with OpReturn; the entry block has no OpLabel.
  addReturn();
  addFunctionEnd();

  auto spvModule = deserialize();
  ASSERT_FALSE(spvModule);
  expectDiagnostic("a basic block must start with OpLabel");
}
279
TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
  addHeader();
  uint32_t voidTy = addVoidType();
  uint32_t fnTy = addFunctionType(voidTy, {});
  addFunction(voidTy, fnTy);
  // A one-word OpLabel: only the opcode word, without the result <id>.
  binary.push_back(
      spirv::getPrefixedOpcode(/*wordCount=*/1, spirv::Opcode::OpLabel));
  addReturn();
  addFunctionEnd();

  auto spvModule = deserialize();
  ASSERT_FALSE(spvModule);
  expectDiagnostic("OpLabel should only have result <id>");
}
292