1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/comp/markv_decoder.h"
16
17 #include <cstring>
18 #include <iterator>
19 #include <numeric>
20
21 #include "source/ext_inst.h"
22 #include "source/opcode.h"
23 #include "spirv-tools/libspirv.hpp"
24
25 namespace spvtools {
26 namespace comp {
27
DecodeNonIdWord(uint32_t * word)28 spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) {
29 auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
30
31 if (codec) {
32 uint64_t decoded_value = 0;
33 if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
34 return Diag(SPV_ERROR_INVALID_BINARY)
35 << "Failed to decode non-id word with Huffman";
36
37 if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
38 // The word decoded successfully.
39 *word = uint32_t(decoded_value);
40 assert(*word == decoded_value);
41 return SPV_SUCCESS;
42 }
43
44 // Received kMarkvNoneOfTheAbove signal, use fallback decoding.
45 }
46
47 const size_t chunk_length =
48 model_->GetOperandVariableWidthChunkLength(operand_.type);
49 if (chunk_length) {
50 if (!reader_.ReadVariableWidthU32(word, chunk_length))
51 return Diag(SPV_ERROR_INVALID_BINARY)
52 << "Failed to decode non-id word with varint";
53 } else {
54 if (!reader_.ReadUnencoded(word))
55 return Diag(SPV_ERROR_INVALID_BINARY)
56 << "Failed to read unencoded non-id word";
57 }
58 return SPV_SUCCESS;
59 }
60
DecodeOpcodeAndNumberOfOperands(uint32_t * opcode,uint32_t * num_operands)61 spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands(
62 uint32_t* opcode, uint32_t* num_operands) {
63 // First try to use the Markov chain codec.
64 auto* codec =
65 model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
66 if (codec) {
67 uint64_t decoded_value = 0;
68 if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
69 return Diag(SPV_ERROR_INTERNAL)
70 << "Failed to decode opcode_and_num_operands, previous opcode is "
71 << spvOpcodeString(GetPrevOpcode());
72
73 if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
74 // The word was successfully decoded.
75 *opcode = uint32_t(decoded_value & 0xFFFF);
76 *num_operands = uint32_t(decoded_value >> 16);
77 return SPV_SUCCESS;
78 }
79
80 // Received kMarkvNoneOfTheAbove signal, use fallback decoding.
81 }
82
83 // Fallback to base-rate codec.
84 codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
85 assert(codec);
86 uint64_t decoded_value = 0;
87 if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
88 return Diag(SPV_ERROR_INTERNAL)
89 << "Failed to decode opcode_and_num_operands with global codec";
90
91 if (decoded_value == MarkvModel::GetMarkvNoneOfTheAbove()) {
92 // Received kMarkvNoneOfTheAbove signal, fallback further.
93 return SPV_UNSUPPORTED;
94 }
95
96 *opcode = uint32_t(decoded_value & 0xFFFF);
97 *num_operands = uint32_t(decoded_value >> 16);
98 return SPV_SUCCESS;
99 }
100
DecodeMtfRankHuffman(uint64_t mtf,uint32_t fallback_method,uint32_t * rank)101 spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf,
102 uint32_t fallback_method,
103 uint32_t* rank) {
104 const auto* codec = GetMtfHuffmanCodec(mtf);
105 if (!codec) {
106 assert(fallback_method != kMtfNone);
107 codec = GetMtfHuffmanCodec(fallback_method);
108 }
109
110 if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank";
111
112 uint32_t decoded_value = 0;
113 if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
114 return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman";
115
116 if (decoded_value == kMtfRankEncodedByValueSignal) {
117 // Decode by value.
118 if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length()))
119 return Diag(SPV_ERROR_INTERNAL)
120 << "Failed to decode MTF rank with varint";
121 *rank += MarkvCodec::kMtfSmallestRankEncodedByValue;
122 } else {
123 // Decode using Huffman coding.
124 assert(decoded_value < MarkvCodec::kMtfSmallestRankEncodedByValue);
125 *rank = decoded_value;
126 }
127 return SPV_SUCCESS;
128 }
129
DecodeIdWithDescriptor(uint32_t * id)130 spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) {
131 auto* codec =
132 model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
133
134 uint64_t mtf = kMtfNone;
135 if (codec) {
136 uint64_t decoded_value = 0;
137 if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
138 return Diag(SPV_ERROR_INTERNAL)
139 << "Failed to decode descriptor with Huffman";
140
141 if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
142 const uint32_t long_descriptor = uint32_t(decoded_value);
143 mtf = GetMtfLongIdDescriptor(long_descriptor);
144 }
145 }
146
147 if (mtf == kMtfNone) {
148 if (model_->id_fallback_strategy() !=
149 MarkvModel::IdFallbackStrategy::kShortDescriptor) {
150 return SPV_UNSUPPORTED;
151 }
152
153 uint64_t decoded_value = 0;
154 if (!reader_.ReadBits(&decoded_value, MarkvCodec::kShortDescriptorNumBits))
155 return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor";
156 const uint32_t short_descriptor = uint32_t(decoded_value);
157 if (short_descriptor == 0) {
158 // Forward declared id.
159 return SPV_UNSUPPORTED;
160 }
161 mtf = GetMtfShortIdDescriptor(short_descriptor);
162 }
163
164 return DecodeExistingId(mtf, id);
165 }
166
DecodeExistingId(uint64_t mtf,uint32_t * id)167 spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) {
168 assert(multi_mtf_.GetSize(mtf) > 0);
169 *id = 0;
170
171 uint32_t rank = 0;
172
173 if (multi_mtf_.GetSize(mtf) == 1) {
174 rank = 1;
175 } else {
176 const spv_result_t result =
177 DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank);
178 if (result != SPV_SUCCESS) return result;
179 }
180
181 assert(rank);
182 if (!multi_mtf_.ValueFromRank(mtf, rank, id))
183 return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds";
184
185 return SPV_SUCCESS;
186 }
187
DecodeRefId(uint32_t * id)188 spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) {
189 {
190 const spv_result_t result = DecodeIdWithDescriptor(id);
191 if (result != SPV_UNSUPPORTED) return result;
192 }
193
194 const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
195 SpvOp(inst_.opcode))(operand_index_);
196 uint32_t rank = 0;
197 *id = 0;
198
199 if (model_->id_fallback_strategy() ==
200 MarkvModel::IdFallbackStrategy::kRuleBased) {
201 uint64_t mtf = GetRuleBasedMtf();
202 if (mtf != kMtfNone && !can_forward_declare) {
203 return DecodeExistingId(mtf, id);
204 }
205
206 if (mtf == kMtfNone) mtf = kMtfAll;
207 {
208 const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank);
209 if (result != SPV_SUCCESS) return result;
210 }
211
212 if (rank == 0) {
213 // This is the first occurrence of a forward declared id.
214 *id = GetIdBound();
215 SetIdBound(*id + 1);
216 multi_mtf_.Insert(kMtfAll, *id);
217 multi_mtf_.Insert(kMtfForwardDeclared, *id);
218 if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id);
219 } else {
220 if (!multi_mtf_.ValueFromRank(mtf, rank, id))
221 return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
222 }
223 } else {
224 assert(can_forward_declare);
225
226 if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
227 return Diag(SPV_ERROR_INTERNAL)
228 << "Failed to decode MTF rank with varint";
229
230 if (rank == 0) {
231 // This is the first occurrence of a forward declared id.
232 *id = GetIdBound();
233 SetIdBound(*id + 1);
234 multi_mtf_.Insert(kMtfForwardDeclared, *id);
235 } else {
236 if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id))
237 return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
238 }
239 }
240 assert(*id);
241 return SPV_SUCCESS;
242 }
243
DecodeTypeId()244 spv_result_t MarkvDecoder::DecodeTypeId() {
245 if (inst_.opcode == SpvOpFunctionParameter) {
246 assert(!remaining_function_parameter_types_.empty());
247 inst_.type_id = remaining_function_parameter_types_.front();
248 remaining_function_parameter_types_.pop_front();
249 return SPV_SUCCESS;
250 }
251
252 {
253 const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id);
254 if (result != SPV_UNSUPPORTED) return result;
255 }
256
257 assert(model_->id_fallback_strategy() ==
258 MarkvModel::IdFallbackStrategy::kRuleBased);
259
260 uint64_t mtf = GetRuleBasedMtf();
261 assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
262 operand_index_));
263
264 if (mtf == kMtfNone) {
265 mtf = kMtfTypeNonFunction;
266 // Function types should have been handled by GetRuleBasedMtf.
267 assert(inst_.opcode != SpvOpFunction);
268 }
269
270 return DecodeExistingId(mtf, &inst_.type_id);
271 }
272
DecodeResultId()273 spv_result_t MarkvDecoder::DecodeResultId() {
274 uint32_t rank = 0;
275
276 const uint64_t num_still_forward_declared =
277 multi_mtf_.GetSize(kMtfForwardDeclared);
278
279 if (num_still_forward_declared) {
280 // Some ids were forward declared. Check if this id is one of them.
281 uint64_t id_was_forward_declared;
282 if (!reader_.ReadBits(&id_was_forward_declared, 1))
283 return Diag(SPV_ERROR_INVALID_BINARY)
284 << "Failed to read id_was_forward_declared flag";
285
286 if (id_was_forward_declared) {
287 if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
288 return Diag(SPV_ERROR_INVALID_BINARY)
289 << "Failed to read MTF rank of forward declared id";
290
291 if (rank) {
292 // The id was forward declared, recover it from kMtfForwardDeclared.
293 if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank,
294 &inst_.result_id))
295 return Diag(SPV_ERROR_INTERNAL)
296 << "Forward declared MTF rank is out of bounds";
297
298 // We can now remove the id from kMtfForwardDeclared.
299 if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
300 return Diag(SPV_ERROR_INTERNAL)
301 << "Failed to remove id from kMtfForwardDeclared";
302 }
303 }
304 }
305
306 if (inst_.result_id == 0) {
307 // The id was not forward declared, issue a new id.
308 inst_.result_id = GetIdBound();
309 SetIdBound(inst_.result_id + 1);
310 }
311
312 if (model_->id_fallback_strategy() ==
313 MarkvModel::IdFallbackStrategy::kRuleBased) {
314 if (!rank) {
315 multi_mtf_.Insert(kMtfAll, inst_.result_id);
316 }
317 }
318
319 return SPV_SUCCESS;
320 }
321
DecodeLiteralNumber(const spv_parsed_operand_t & operand)322 spv_result_t MarkvDecoder::DecodeLiteralNumber(
323 const spv_parsed_operand_t& operand) {
324 if (operand.number_bit_width <= 32) {
325 uint32_t word = 0;
326 const spv_result_t result = DecodeNonIdWord(&word);
327 if (result != SPV_SUCCESS) return result;
328 inst_words_.push_back(word);
329 } else {
330 assert(operand.number_bit_width <= 64);
331 uint64_t word = 0;
332 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
333 if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
334 return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64";
335 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
336 int64_t val = 0;
337 if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
338 model_->s64_block_exponent()))
339 return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64";
340 std::memcpy(&word, &val, 8);
341 } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
342 if (!reader_.ReadUnencoded(&word))
343 return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64";
344 } else {
345 return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
346 }
347 inst_words_.push_back(static_cast<uint32_t>(word));
348 inst_words_.push_back(static_cast<uint32_t>(word >> 32));
349 }
350 return SPV_SUCCESS;
351 }
352
ReadToByteBreak(size_t byte_break_if_less_than)353 bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) {
354 const size_t num_bits_to_next_byte =
355 GetNumBitsToNextByte(reader_.GetNumReadBits());
356 if (num_bits_to_next_byte == 0 ||
357 num_bits_to_next_byte > byte_break_if_less_than)
358 return true;
359
360 uint64_t bits = 0;
361 if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false;
362
363 assert(bits == 0);
364 if (bits != 0) return false;
365
366 return true;
367 }
368
DecodeModule(std::vector<uint32_t> * spirv_binary)369 spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
370 const bool header_read_success =
371 reader_.ReadUnencoded(&header_.magic_number) &&
372 reader_.ReadUnencoded(&header_.markv_version) &&
373 reader_.ReadUnencoded(&header_.markv_model) &&
374 reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
375 reader_.ReadUnencoded(&header_.spirv_version) &&
376 reader_.ReadUnencoded(&header_.spirv_generator);
377
378 if (!header_read_success)
379 return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header";
380
381 if (header_.markv_length_in_bits == 0)
382 return Diag(SPV_ERROR_INVALID_BINARY)
383 << "Header markv_length_in_bits field is zero";
384
385 if (header_.magic_number != MarkvCodec::kMarkvMagicNumber)
386 return Diag(SPV_ERROR_INVALID_BINARY)
387 << "MARK-V binary has incorrect magic number";
388
389 // TODO(atgoo@github.com): Print version strings.
390 if (header_.markv_version != MarkvCodec::GetMarkvVersion())
391 return Diag(SPV_ERROR_INVALID_BINARY)
392 << "MARK-V binary and the codec have different versions";
393
394 const uint32_t model_type = header_.markv_model >> 16;
395 const uint32_t model_version = header_.markv_model & 0xFFFF;
396 if (model_type != model_->model_type())
397 return Diag(SPV_ERROR_INVALID_BINARY)
398 << "MARK-V binary and the codec use different MARK-V models";
399
400 if (model_version != model_->model_version())
401 return Diag(SPV_ERROR_INVALID_BINARY)
402 << "MARK-V binary and the codec use different versions if the same "
403 << "MARK-V model";
404
405 spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic.
406 spirv_.resize(5, 0);
407 spirv_[0] = SpvMagicNumber;
408 spirv_[1] = header_.spirv_version;
409 spirv_[2] = header_.spirv_generator;
410
411 if (logger_) {
412 reader_.SetCallback(
413 [this](const std::string& str) { logger_->AppendBitSequence(str); });
414 }
415
416 while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
417 inst_ = {};
418 const spv_result_t decode_result = DecodeInstruction();
419 if (decode_result != SPV_SUCCESS) return decode_result;
420 }
421
422 if (validator_options_) {
423 spv_const_binary_t validation_binary = {spirv_.data(), spirv_.size()};
424 const spv_result_t result = spvValidateWithOptions(
425 context_, validator_options_, &validation_binary, nullptr);
426 if (result != SPV_SUCCESS) return result;
427 }
428
429 // Validate the decode binary
430 if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
431 !reader_.OnlyZeroesLeft()) {
432 return Diag(SPV_ERROR_INVALID_BINARY)
433 << "MARK-V binary has wrong stated bit length "
434 << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
435 }
436
437 // Decoding of the module is finished, validation state should have correct
438 // id bound.
439 spirv_[3] = GetIdBound();
440
441 *spirv_binary = std::move(spirv_);
442 return SPV_SUCCESS;
443 }
444
445 // TODO(atgoo@github.com): The implementation borrows heavily from
446 // Parser::parseOperand.
447 // Consider coupling them together in some way once MARK-V codec is more mature.
448 // For now it's better to keep the code independent for experimentation
449 // purposes.
DecodeOperand(size_t operand_offset,const spv_operand_type_t type,spv_operand_pattern_t * expected_operands)450 spv_result_t MarkvDecoder::DecodeOperand(
451 size_t operand_offset, const spv_operand_type_t type,
452 spv_operand_pattern_t* expected_operands) {
453 const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
454
455 memset(&operand_, 0, sizeof(operand_));
456
457 assert((operand_offset >> 16) == 0);
458 operand_.offset = static_cast<uint16_t>(operand_offset);
459 operand_.type = type;
460
461 // Set default values, may be updated later.
462 operand_.number_kind = SPV_NUMBER_NONE;
463 operand_.number_bit_width = 0;
464
465 const size_t first_word_index = inst_words_.size();
466
467 switch (type) {
468 case SPV_OPERAND_TYPE_RESULT_ID: {
469 const spv_result_t result = DecodeResultId();
470 if (result != SPV_SUCCESS) return result;
471
472 inst_words_.push_back(inst_.result_id);
473 SetIdBound(std::max(GetIdBound(), inst_.result_id + 1));
474 PromoteIfNeeded(inst_.result_id);
475 break;
476 }
477
478 case SPV_OPERAND_TYPE_TYPE_ID: {
479 const spv_result_t result = DecodeTypeId();
480 if (result != SPV_SUCCESS) return result;
481
482 inst_words_.push_back(inst_.type_id);
483 SetIdBound(std::max(GetIdBound(), inst_.type_id + 1));
484 PromoteIfNeeded(inst_.type_id);
485 break;
486 }
487
488 case SPV_OPERAND_TYPE_ID:
489 case SPV_OPERAND_TYPE_OPTIONAL_ID:
490 case SPV_OPERAND_TYPE_SCOPE_ID:
491 case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
492 uint32_t id = 0;
493 const spv_result_t result = DecodeRefId(&id);
494 if (result != SPV_SUCCESS) return result;
495
496 if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
497
498 if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
499 operand_.type = SPV_OPERAND_TYPE_ID;
500
501 if (opcode == SpvOpExtInst && operand_.offset == 3) {
502 // The current word is the extended instruction set id.
503 // Set the extended instruction set type for the current
504 // instruction.
505 auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
506 if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
507 return Diag(SPV_ERROR_INVALID_ID)
508 << "OpExtInst set id " << id
509 << " does not reference an OpExtInstImport result Id";
510 }
511 inst_.ext_inst_type = ext_inst_type_iter->second;
512 }
513 }
514
515 inst_words_.push_back(id);
516 SetIdBound(std::max(GetIdBound(), id + 1));
517 PromoteIfNeeded(id);
518 break;
519 }
520
521 case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
522 uint32_t word = 0;
523 const spv_result_t result = DecodeNonIdWord(&word);
524 if (result != SPV_SUCCESS) return result;
525
526 inst_words_.push_back(word);
527
528 assert(SpvOpExtInst == opcode);
529 assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE);
530 spv_ext_inst_desc ext_inst;
531 if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst))
532 return Diag(SPV_ERROR_INVALID_BINARY)
533 << "Invalid extended instruction number: " << word;
534 spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
535 break;
536 }
537
538 case SPV_OPERAND_TYPE_LITERAL_INTEGER:
539 case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
540 // These are regular single-word literal integer operands.
541 // Post-parsing validation should check the range of the parsed value.
542 operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
543 // It turns out they are always unsigned integers!
544 operand_.number_kind = SPV_NUMBER_UNSIGNED_INT;
545 operand_.number_bit_width = 32;
546
547 uint32_t word = 0;
548 const spv_result_t result = DecodeNonIdWord(&word);
549 if (result != SPV_SUCCESS) return result;
550
551 inst_words_.push_back(word);
552 break;
553 }
554
555 case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
556 case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: {
557 operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
558 if (opcode == SpvOpSwitch) {
559 // The literal operands have the same type as the value
560 // referenced by the selector Id.
561 const uint32_t selector_id = inst_words_.at(1);
562 const auto type_id_iter = id_to_type_id_.find(selector_id);
563 if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) {
564 return Diag(SPV_ERROR_INVALID_BINARY)
565 << "Invalid OpSwitch: selector id " << selector_id
566 << " has no type";
567 }
568 uint32_t type_id = type_id_iter->second;
569
570 if (selector_id == type_id) {
571 // Recall that by convention, a result ID that is a type definition
572 // maps to itself.
573 return Diag(SPV_ERROR_INVALID_BINARY)
574 << "Invalid OpSwitch: selector id " << selector_id
575 << " is a type, not a value";
576 }
577 if (auto error = SetNumericTypeInfoForType(&operand_, type_id))
578 return error;
579 if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT &&
580 operand_.number_kind != SPV_NUMBER_SIGNED_INT) {
581 return Diag(SPV_ERROR_INVALID_BINARY)
582 << "Invalid OpSwitch: selector id " << selector_id
583 << " is not a scalar integer";
584 }
585 } else {
586 assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
587 // The literal number type is determined by the type Id for the
588 // constant.
589 assert(inst_.type_id);
590 if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id))
591 return error;
592 }
593
594 if (auto error = DecodeLiteralNumber(operand_)) return error;
595
596 break;
597 }
598
599 case SPV_OPERAND_TYPE_LITERAL_STRING:
600 case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
601 operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING;
602 std::vector<char> str;
603 auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode);
604
605 if (codec) {
606 std::string decoded_string;
607 const bool huffman_result =
608 codec->DecodeFromStream(GetReadBitCallback(), &decoded_string);
609 assert(huffman_result);
610 if (!huffman_result)
611 return Diag(SPV_ERROR_INVALID_BINARY)
612 << "Failed to read literal string";
613
614 if (decoded_string != "kMarkvNoneOfTheAbove") {
615 std::copy(decoded_string.begin(), decoded_string.end(),
616 std::back_inserter(str));
617 str.push_back('\0');
618 }
619 }
620
621 // The loop is expected to terminate once we encounter '\0' or exhaust
622 // the bit stream.
623 if (str.empty()) {
624 while (true) {
625 char ch = 0;
626 if (!reader_.ReadUnencoded(&ch))
627 return Diag(SPV_ERROR_INVALID_BINARY)
628 << "Failed to read literal string";
629
630 str.push_back(ch);
631
632 if (ch == '\0') break;
633 }
634 }
635
636 while (str.size() % 4 != 0) str.push_back('\0');
637
638 inst_words_.resize(inst_words_.size() + str.size() / 4);
639 std::memcpy(&inst_words_[first_word_index], str.data(), str.size());
640
641 if (SpvOpExtInstImport == opcode) {
642 // Record the extended instruction type for the ID for this import.
643 // There is only one string literal argument to OpExtInstImport,
644 // so it's sufficient to guard this just on the opcode.
645 const spv_ext_inst_type_t ext_inst_type =
646 spvExtInstImportTypeGet(str.data());
647 if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
648 return Diag(SPV_ERROR_INVALID_BINARY)
649 << "Invalid extended instruction import '" << str.data()
650 << "'";
651 }
652 // We must have parsed a valid result ID. It's a condition
653 // of the grammar, and we only accept non-zero result Ids.
654 assert(inst_.result_id);
655 const bool inserted =
656 import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type)
657 .second;
658 (void)inserted;
659 assert(inserted);
660 }
661 break;
662 }
663
664 case SPV_OPERAND_TYPE_CAPABILITY:
665 case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
666 case SPV_OPERAND_TYPE_EXECUTION_MODEL:
667 case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
668 case SPV_OPERAND_TYPE_MEMORY_MODEL:
669 case SPV_OPERAND_TYPE_EXECUTION_MODE:
670 case SPV_OPERAND_TYPE_STORAGE_CLASS:
671 case SPV_OPERAND_TYPE_DIMENSIONALITY:
672 case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
673 case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
674 case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
675 case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
676 case SPV_OPERAND_TYPE_LINKAGE_TYPE:
677 case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
678 case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
679 case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
680 case SPV_OPERAND_TYPE_DECORATION:
681 case SPV_OPERAND_TYPE_BUILT_IN:
682 case SPV_OPERAND_TYPE_GROUP_OPERATION:
683 case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
684 case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
685 // A single word that is a plain enum value.
686 uint32_t word = 0;
687 const spv_result_t result = DecodeNonIdWord(&word);
688 if (result != SPV_SUCCESS) return result;
689
690 inst_words_.push_back(word);
691
692 // Map an optional operand type to its corresponding concrete type.
693 if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
694 operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
695
696 spv_operand_desc entry;
697 if (grammar_.lookupOperand(type, word, &entry)) {
698 return Diag(SPV_ERROR_INVALID_BINARY)
699 << "Invalid " << spvOperandTypeStr(operand_.type)
700 << " operand: " << word;
701 }
702
703 // Prepare to accept operands to this operand, if needed.
704 spvPushOperandTypes(entry->operandTypes, expected_operands);
705 break;
706 }
707
708 case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
709 case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
710 case SPV_OPERAND_TYPE_LOOP_CONTROL:
711 case SPV_OPERAND_TYPE_IMAGE:
712 case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
713 case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
714 case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
715 // This operand is a mask.
716 uint32_t word = 0;
717 const spv_result_t result = DecodeNonIdWord(&word);
718 if (result != SPV_SUCCESS) return result;
719
720 inst_words_.push_back(word);
721
722 // Map an optional operand type to its corresponding concrete type.
723 if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
724 operand_.type = SPV_OPERAND_TYPE_IMAGE;
725 else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
726 operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
727
728 // Check validity of set mask bits. Also prepare for operands for those
729 // masks if they have any. To get operand order correct, scan from
730 // MSB to LSB since we can only prepend operands to a pattern.
731 // The only case in the grammar where you have more than one mask bit
732 // having an operand is for image operands. See SPIR-V 3.14 Image
733 // Operands.
734 uint32_t remaining_word = word;
735 for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
736 if (remaining_word & mask) {
737 spv_operand_desc entry;
738 if (grammar_.lookupOperand(type, mask, &entry)) {
739 return Diag(SPV_ERROR_INVALID_BINARY)
740 << "Invalid " << spvOperandTypeStr(operand_.type)
741 << " operand: " << word << " has invalid mask component "
742 << mask;
743 }
744 remaining_word ^= mask;
745 spvPushOperandTypes(entry->operandTypes, expected_operands);
746 }
747 }
748 if (word == 0) {
749 // An all-zeroes mask *might* also be valid.
750 spv_operand_desc entry;
751 if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
752 // Prepare for its operands, if any.
753 spvPushOperandTypes(entry->operandTypes, expected_operands);
754 }
755 }
756 break;
757 }
758 default:
759 return Diag(SPV_ERROR_INVALID_BINARY)
760 << "Internal error: Unhandled operand type: " << type;
761 }
762
763 operand_.num_words = uint16_t(inst_words_.size() - first_word_index);
764
765 assert(spvOperandIsConcrete(operand_.type));
766
767 parsed_operands_.push_back(operand_);
768
769 return SPV_SUCCESS;
770 }
771
DecodeInstruction()772 spv_result_t MarkvDecoder::DecodeInstruction() {
773 parsed_operands_.clear();
774 inst_words_.clear();
775
776 // Opcode/num_words placeholder, the word will be filled in later.
777 inst_words_.push_back(0);
778
779 bool num_operands_still_unknown = true;
780 {
781 uint32_t opcode = 0;
782 uint32_t num_operands = 0;
783
784 const spv_result_t opcode_decoding_result =
785 DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands);
786 if (opcode_decoding_result < 0) return opcode_decoding_result;
787
788 if (opcode_decoding_result == SPV_SUCCESS) {
789 inst_.num_operands = static_cast<uint16_t>(num_operands);
790 num_operands_still_unknown = false;
791 } else {
792 if (!reader_.ReadVariableWidthU32(&opcode,
793 model_->opcode_chunk_length())) {
794 return Diag(SPV_ERROR_INVALID_BINARY)
795 << "Failed to read opcode of instruction";
796 }
797 }
798
799 inst_.opcode = static_cast<uint16_t>(opcode);
800 }
801
802 const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
803
804 spv_opcode_desc opcode_desc;
805 if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
806 return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
807 }
808
809 spv_operand_pattern_t expected_operands;
810 expected_operands.reserve(opcode_desc->numTypes);
811 for (auto i = 0; i < opcode_desc->numTypes; i++) {
812 expected_operands.push_back(
813 opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
814 }
815
816 if (num_operands_still_unknown) {
817 if (!OpcodeHasFixedNumberOfOperands(opcode)) {
818 if (!reader_.ReadVariableWidthU16(&inst_.num_operands,
819 model_->num_operands_chunk_length()))
820 return Diag(SPV_ERROR_INVALID_BINARY)
821 << "Failed to read num_operands of instruction";
822 } else {
823 inst_.num_operands = static_cast<uint16_t>(expected_operands.size());
824 }
825 }
826
827 for (operand_index_ = 0;
828 operand_index_ < static_cast<size_t>(inst_.num_operands);
829 ++operand_index_) {
830 assert(!expected_operands.empty());
831 const spv_operand_type_t type =
832 spvTakeFirstMatchableOperand(&expected_operands);
833
834 const size_t operand_offset = inst_words_.size();
835
836 const spv_result_t decode_result =
837 DecodeOperand(operand_offset, type, &expected_operands);
838
839 if (decode_result != SPV_SUCCESS) return decode_result;
840 }
841
842 assert(inst_.num_operands == parsed_operands_.size());
843
844 // Only valid while inst_words_ and parsed_operands_ remain unchanged (until
845 // next DecodeInstruction call).
846 inst_.words = inst_words_.data();
847 inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
848 inst_.num_words = static_cast<uint16_t>(inst_words_.size());
849 inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode));
850
851 std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_));
852
853 assert(inst_.num_words ==
854 std::accumulate(
855 parsed_operands_.begin(), parsed_operands_.end(), 1,
856 [](int num_words, const spv_parsed_operand_t& operand) {
857 return num_words += operand.num_words;
858 }) &&
859 "num_words in instruction doesn't correspond to the sum of num_words"
860 "in the operands");
861
862 RecordNumberType();
863 ProcessCurInstruction();
864
865 if (!ReadToByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte))
866 return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break";
867
868 if (logger_) {
869 logger_->NewLine();
870 std::stringstream ss;
871 ss << spvOpcodeString(opcode) << " ";
872 for (size_t index = 1; index < inst_words_.size(); ++index)
873 ss << inst_words_[index] << " ";
874 logger_->AppendText(ss.str());
875 logger_->NewLine();
876 logger_->NewLine();
877 if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
878 }
879
880 return SPV_SUCCESS;
881 }
882
SetNumericTypeInfoForType(spv_parsed_operand_t * parsed_operand,uint32_t type_id)883 spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
884 spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
885 assert(type_id != 0);
886 auto type_info_iter = type_id_to_number_type_info_.find(type_id);
887 if (type_info_iter == type_id_to_number_type_info_.end()) {
888 return Diag(SPV_ERROR_INVALID_BINARY)
889 << "Type Id " << type_id << " is not a type";
890 }
891
892 const NumberType& info = type_info_iter->second;
893 if (info.type == SPV_NUMBER_NONE) {
894 // This is a valid type, but for something other than a scalar number.
895 return Diag(SPV_ERROR_INVALID_BINARY)
896 << "Type Id " << type_id << " is not a scalar numeric type";
897 }
898
899 parsed_operand->number_kind = info.type;
900 parsed_operand->number_bit_width = info.bit_width;
901 // Round up the word count.
902 parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
903 return SPV_SUCCESS;
904 }
905
RecordNumberType()906 void MarkvDecoder::RecordNumberType() {
907 const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
908 if (spvOpcodeGeneratesType(opcode)) {
909 NumberType info = {SPV_NUMBER_NONE, 0};
910 if (SpvOpTypeInt == opcode) {
911 info.bit_width = inst_.words[inst_.operands[1].offset];
912 info.type = inst_.words[inst_.operands[2].offset]
913 ? SPV_NUMBER_SIGNED_INT
914 : SPV_NUMBER_UNSIGNED_INT;
915 } else if (SpvOpTypeFloat == opcode) {
916 info.bit_width = inst_.words[inst_.operands[1].offset];
917 info.type = SPV_NUMBER_FLOATING;
918 }
919 // The *result* Id of a type generating instruction is the type Id.
920 type_id_to_number_type_info_[inst_.result_id] = info;
921 }
922 }
923
924 } // namespace comp
925 } // namespace spvtools
926