1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/comp/markv_decoder.h"
16 
17 #include <cstring>
18 #include <iterator>
19 #include <numeric>
20 
21 #include "source/ext_inst.h"
22 #include "source/opcode.h"
23 #include "spirv-tools/libspirv.hpp"
24 
25 namespace spvtools {
26 namespace comp {
27 
DecodeNonIdWord(uint32_t * word)28 spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) {
29   auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
30 
31   if (codec) {
32     uint64_t decoded_value = 0;
33     if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
34       return Diag(SPV_ERROR_INVALID_BINARY)
35              << "Failed to decode non-id word with Huffman";
36 
37     if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
38       // The word decoded successfully.
39       *word = uint32_t(decoded_value);
40       assert(*word == decoded_value);
41       return SPV_SUCCESS;
42     }
43 
44     // Received kMarkvNoneOfTheAbove signal, use fallback decoding.
45   }
46 
47   const size_t chunk_length =
48       model_->GetOperandVariableWidthChunkLength(operand_.type);
49   if (chunk_length) {
50     if (!reader_.ReadVariableWidthU32(word, chunk_length))
51       return Diag(SPV_ERROR_INVALID_BINARY)
52              << "Failed to decode non-id word with varint";
53   } else {
54     if (!reader_.ReadUnencoded(word))
55       return Diag(SPV_ERROR_INVALID_BINARY)
56              << "Failed to read unencoded non-id word";
57   }
58   return SPV_SUCCESS;
59 }
60 
DecodeOpcodeAndNumberOfOperands(uint32_t * opcode,uint32_t * num_operands)61 spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands(
62     uint32_t* opcode, uint32_t* num_operands) {
63   // First try to use the Markov chain codec.
64   auto* codec =
65       model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
66   if (codec) {
67     uint64_t decoded_value = 0;
68     if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
69       return Diag(SPV_ERROR_INTERNAL)
70              << "Failed to decode opcode_and_num_operands, previous opcode is "
71              << spvOpcodeString(GetPrevOpcode());
72 
73     if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
74       // The word was successfully decoded.
75       *opcode = uint32_t(decoded_value & 0xFFFF);
76       *num_operands = uint32_t(decoded_value >> 16);
77       return SPV_SUCCESS;
78     }
79 
80     // Received kMarkvNoneOfTheAbove signal, use fallback decoding.
81   }
82 
83   // Fallback to base-rate codec.
84   codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
85   assert(codec);
86   uint64_t decoded_value = 0;
87   if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
88     return Diag(SPV_ERROR_INTERNAL)
89            << "Failed to decode opcode_and_num_operands with global codec";
90 
91   if (decoded_value == MarkvModel::GetMarkvNoneOfTheAbove()) {
92     // Received kMarkvNoneOfTheAbove signal, fallback further.
93     return SPV_UNSUPPORTED;
94   }
95 
96   *opcode = uint32_t(decoded_value & 0xFFFF);
97   *num_operands = uint32_t(decoded_value >> 16);
98   return SPV_SUCCESS;
99 }
100 
DecodeMtfRankHuffman(uint64_t mtf,uint32_t fallback_method,uint32_t * rank)101 spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf,
102                                                 uint32_t fallback_method,
103                                                 uint32_t* rank) {
104   const auto* codec = GetMtfHuffmanCodec(mtf);
105   if (!codec) {
106     assert(fallback_method != kMtfNone);
107     codec = GetMtfHuffmanCodec(fallback_method);
108   }
109 
110   if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank";
111 
112   uint32_t decoded_value = 0;
113   if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
114     return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman";
115 
116   if (decoded_value == kMtfRankEncodedByValueSignal) {
117     // Decode by value.
118     if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length()))
119       return Diag(SPV_ERROR_INTERNAL)
120              << "Failed to decode MTF rank with varint";
121     *rank += MarkvCodec::kMtfSmallestRankEncodedByValue;
122   } else {
123     // Decode using Huffman coding.
124     assert(decoded_value < MarkvCodec::kMtfSmallestRankEncodedByValue);
125     *rank = decoded_value;
126   }
127   return SPV_SUCCESS;
128 }
129 
DecodeIdWithDescriptor(uint32_t * id)130 spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) {
131   auto* codec =
132       model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
133 
134   uint64_t mtf = kMtfNone;
135   if (codec) {
136     uint64_t decoded_value = 0;
137     if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
138       return Diag(SPV_ERROR_INTERNAL)
139              << "Failed to decode descriptor with Huffman";
140 
141     if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
142       const uint32_t long_descriptor = uint32_t(decoded_value);
143       mtf = GetMtfLongIdDescriptor(long_descriptor);
144     }
145   }
146 
147   if (mtf == kMtfNone) {
148     if (model_->id_fallback_strategy() !=
149         MarkvModel::IdFallbackStrategy::kShortDescriptor) {
150       return SPV_UNSUPPORTED;
151     }
152 
153     uint64_t decoded_value = 0;
154     if (!reader_.ReadBits(&decoded_value, MarkvCodec::kShortDescriptorNumBits))
155       return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor";
156     const uint32_t short_descriptor = uint32_t(decoded_value);
157     if (short_descriptor == 0) {
158       // Forward declared id.
159       return SPV_UNSUPPORTED;
160     }
161     mtf = GetMtfShortIdDescriptor(short_descriptor);
162   }
163 
164   return DecodeExistingId(mtf, id);
165 }
166 
DecodeExistingId(uint64_t mtf,uint32_t * id)167 spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) {
168   assert(multi_mtf_.GetSize(mtf) > 0);
169   *id = 0;
170 
171   uint32_t rank = 0;
172 
173   if (multi_mtf_.GetSize(mtf) == 1) {
174     rank = 1;
175   } else {
176     const spv_result_t result =
177         DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank);
178     if (result != SPV_SUCCESS) return result;
179   }
180 
181   assert(rank);
182   if (!multi_mtf_.ValueFromRank(mtf, rank, id))
183     return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds";
184 
185   return SPV_SUCCESS;
186 }
187 
DecodeRefId(uint32_t * id)188 spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) {
189   {
190     const spv_result_t result = DecodeIdWithDescriptor(id);
191     if (result != SPV_UNSUPPORTED) return result;
192   }
193 
194   const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
195       SpvOp(inst_.opcode))(operand_index_);
196   uint32_t rank = 0;
197   *id = 0;
198 
199   if (model_->id_fallback_strategy() ==
200       MarkvModel::IdFallbackStrategy::kRuleBased) {
201     uint64_t mtf = GetRuleBasedMtf();
202     if (mtf != kMtfNone && !can_forward_declare) {
203       return DecodeExistingId(mtf, id);
204     }
205 
206     if (mtf == kMtfNone) mtf = kMtfAll;
207     {
208       const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank);
209       if (result != SPV_SUCCESS) return result;
210     }
211 
212     if (rank == 0) {
213       // This is the first occurrence of a forward declared id.
214       *id = GetIdBound();
215       SetIdBound(*id + 1);
216       multi_mtf_.Insert(kMtfAll, *id);
217       multi_mtf_.Insert(kMtfForwardDeclared, *id);
218       if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id);
219     } else {
220       if (!multi_mtf_.ValueFromRank(mtf, rank, id))
221         return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
222     }
223   } else {
224     assert(can_forward_declare);
225 
226     if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
227       return Diag(SPV_ERROR_INTERNAL)
228              << "Failed to decode MTF rank with varint";
229 
230     if (rank == 0) {
231       // This is the first occurrence of a forward declared id.
232       *id = GetIdBound();
233       SetIdBound(*id + 1);
234       multi_mtf_.Insert(kMtfForwardDeclared, *id);
235     } else {
236       if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id))
237         return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
238     }
239   }
240   assert(*id);
241   return SPV_SUCCESS;
242 }
243 
DecodeTypeId()244 spv_result_t MarkvDecoder::DecodeTypeId() {
245   if (inst_.opcode == SpvOpFunctionParameter) {
246     assert(!remaining_function_parameter_types_.empty());
247     inst_.type_id = remaining_function_parameter_types_.front();
248     remaining_function_parameter_types_.pop_front();
249     return SPV_SUCCESS;
250   }
251 
252   {
253     const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id);
254     if (result != SPV_UNSUPPORTED) return result;
255   }
256 
257   assert(model_->id_fallback_strategy() ==
258          MarkvModel::IdFallbackStrategy::kRuleBased);
259 
260   uint64_t mtf = GetRuleBasedMtf();
261   assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
262       operand_index_));
263 
264   if (mtf == kMtfNone) {
265     mtf = kMtfTypeNonFunction;
266     // Function types should have been handled by GetRuleBasedMtf.
267     assert(inst_.opcode != SpvOpFunction);
268   }
269 
270   return DecodeExistingId(mtf, &inst_.type_id);
271 }
272 
DecodeResultId()273 spv_result_t MarkvDecoder::DecodeResultId() {
274   uint32_t rank = 0;
275 
276   const uint64_t num_still_forward_declared =
277       multi_mtf_.GetSize(kMtfForwardDeclared);
278 
279   if (num_still_forward_declared) {
280     // Some ids were forward declared. Check if this id is one of them.
281     uint64_t id_was_forward_declared;
282     if (!reader_.ReadBits(&id_was_forward_declared, 1))
283       return Diag(SPV_ERROR_INVALID_BINARY)
284              << "Failed to read id_was_forward_declared flag";
285 
286     if (id_was_forward_declared) {
287       if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
288         return Diag(SPV_ERROR_INVALID_BINARY)
289                << "Failed to read MTF rank of forward declared id";
290 
291       if (rank) {
292         // The id was forward declared, recover it from kMtfForwardDeclared.
293         if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank,
294                                       &inst_.result_id))
295           return Diag(SPV_ERROR_INTERNAL)
296                  << "Forward declared MTF rank is out of bounds";
297 
298         // We can now remove the id from kMtfForwardDeclared.
299         if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
300           return Diag(SPV_ERROR_INTERNAL)
301                  << "Failed to remove id from kMtfForwardDeclared";
302       }
303     }
304   }
305 
306   if (inst_.result_id == 0) {
307     // The id was not forward declared, issue a new id.
308     inst_.result_id = GetIdBound();
309     SetIdBound(inst_.result_id + 1);
310   }
311 
312   if (model_->id_fallback_strategy() ==
313       MarkvModel::IdFallbackStrategy::kRuleBased) {
314     if (!rank) {
315       multi_mtf_.Insert(kMtfAll, inst_.result_id);
316     }
317   }
318 
319   return SPV_SUCCESS;
320 }
321 
DecodeLiteralNumber(const spv_parsed_operand_t & operand)322 spv_result_t MarkvDecoder::DecodeLiteralNumber(
323     const spv_parsed_operand_t& operand) {
324   if (operand.number_bit_width <= 32) {
325     uint32_t word = 0;
326     const spv_result_t result = DecodeNonIdWord(&word);
327     if (result != SPV_SUCCESS) return result;
328     inst_words_.push_back(word);
329   } else {
330     assert(operand.number_bit_width <= 64);
331     uint64_t word = 0;
332     if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
333       if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
334         return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64";
335     } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
336       int64_t val = 0;
337       if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
338                                         model_->s64_block_exponent()))
339         return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64";
340       std::memcpy(&word, &val, 8);
341     } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
342       if (!reader_.ReadUnencoded(&word))
343         return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64";
344     } else {
345       return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
346     }
347     inst_words_.push_back(static_cast<uint32_t>(word));
348     inst_words_.push_back(static_cast<uint32_t>(word >> 32));
349   }
350   return SPV_SUCCESS;
351 }
352 
ReadToByteBreak(size_t byte_break_if_less_than)353 bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) {
354   const size_t num_bits_to_next_byte =
355       GetNumBitsToNextByte(reader_.GetNumReadBits());
356   if (num_bits_to_next_byte == 0 ||
357       num_bits_to_next_byte > byte_break_if_less_than)
358     return true;
359 
360   uint64_t bits = 0;
361   if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false;
362 
363   assert(bits == 0);
364   if (bits != 0) return false;
365 
366   return true;
367 }
368 
DecodeModule(std::vector<uint32_t> * spirv_binary)369 spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
370   const bool header_read_success =
371       reader_.ReadUnencoded(&header_.magic_number) &&
372       reader_.ReadUnencoded(&header_.markv_version) &&
373       reader_.ReadUnencoded(&header_.markv_model) &&
374       reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
375       reader_.ReadUnencoded(&header_.spirv_version) &&
376       reader_.ReadUnencoded(&header_.spirv_generator);
377 
378   if (!header_read_success)
379     return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header";
380 
381   if (header_.markv_length_in_bits == 0)
382     return Diag(SPV_ERROR_INVALID_BINARY)
383            << "Header markv_length_in_bits field is zero";
384 
385   if (header_.magic_number != MarkvCodec::kMarkvMagicNumber)
386     return Diag(SPV_ERROR_INVALID_BINARY)
387            << "MARK-V binary has incorrect magic number";
388 
389   // TODO(atgoo@github.com): Print version strings.
390   if (header_.markv_version != MarkvCodec::GetMarkvVersion())
391     return Diag(SPV_ERROR_INVALID_BINARY)
392            << "MARK-V binary and the codec have different versions";
393 
394   const uint32_t model_type = header_.markv_model >> 16;
395   const uint32_t model_version = header_.markv_model & 0xFFFF;
396   if (model_type != model_->model_type())
397     return Diag(SPV_ERROR_INVALID_BINARY)
398            << "MARK-V binary and the codec use different MARK-V models";
399 
400   if (model_version != model_->model_version())
401     return Diag(SPV_ERROR_INVALID_BINARY)
402            << "MARK-V binary and the codec use different versions if the same "
403            << "MARK-V model";
404 
405   spirv_.reserve(header_.markv_length_in_bits / 2);  // Heuristic.
406   spirv_.resize(5, 0);
407   spirv_[0] = SpvMagicNumber;
408   spirv_[1] = header_.spirv_version;
409   spirv_[2] = header_.spirv_generator;
410 
411   if (logger_) {
412     reader_.SetCallback(
413         [this](const std::string& str) { logger_->AppendBitSequence(str); });
414   }
415 
416   while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
417     inst_ = {};
418     const spv_result_t decode_result = DecodeInstruction();
419     if (decode_result != SPV_SUCCESS) return decode_result;
420   }
421 
422   if (validator_options_) {
423     spv_const_binary_t validation_binary = {spirv_.data(), spirv_.size()};
424     const spv_result_t result = spvValidateWithOptions(
425         context_, validator_options_, &validation_binary, nullptr);
426     if (result != SPV_SUCCESS) return result;
427   }
428 
429   // Validate the decode binary
430   if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
431       !reader_.OnlyZeroesLeft()) {
432     return Diag(SPV_ERROR_INVALID_BINARY)
433            << "MARK-V binary has wrong stated bit length "
434            << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
435   }
436 
437   // Decoding of the module is finished, validation state should have correct
438   // id bound.
439   spirv_[3] = GetIdBound();
440 
441   *spirv_binary = std::move(spirv_);
442   return SPV_SUCCESS;
443 }
444 
445 // TODO(atgoo@github.com): The implementation borrows heavily from
446 // Parser::parseOperand.
447 // Consider coupling them together in some way once MARK-V codec is more mature.
448 // For now it's better to keep the code independent for experimentation
449 // purposes.
// Decodes one operand of the current instruction from the MARK-V stream.
//
// |operand_offset| is the word offset of the operand inside the instruction
// (asserted to fit in 16 bits). |type| is the operand type to decode.
// |expected_operands| is the pattern of operand types still expected for the
// instruction; decoding an extended instruction number, an enum value or a
// mask may push additional operand types onto it.
//
// On success the decoded words are appended to inst_words_, a description of
// the operand (operand_) is appended to parsed_operands_ and SPV_SUCCESS is
// returned. On failure the error code of the failing step is returned.
spv_result_t MarkvDecoder::DecodeOperand(
    size_t operand_offset, const spv_operand_type_t type,
    spv_operand_pattern_t* expected_operands) {
  const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);

  // Start from a clean operand description.
  memset(&operand_, 0, sizeof(operand_));

  assert((operand_offset >> 16) == 0);
  operand_.offset = static_cast<uint16_t>(operand_offset);
  operand_.type = type;

  // Set default values, may be updated later.
  operand_.number_kind = SPV_NUMBER_NONE;
  operand_.number_bit_width = 0;

  // Index of the first word this operand will occupy in inst_words_; used at
  // the end to compute the operand's word count.
  const size_t first_word_index = inst_words_.size();

  switch (type) {
    case SPV_OPERAND_TYPE_RESULT_ID: {
      const spv_result_t result = DecodeResultId();
      if (result != SPV_SUCCESS) return result;

      inst_words_.push_back(inst_.result_id);
      // Keep the id bound above every id seen so far.
      SetIdBound(std::max(GetIdBound(), inst_.result_id + 1));
      PromoteIfNeeded(inst_.result_id);
      break;
    }

    case SPV_OPERAND_TYPE_TYPE_ID: {
      const spv_result_t result = DecodeTypeId();
      if (result != SPV_SUCCESS) return result;

      inst_words_.push_back(inst_.type_id);
      SetIdBound(std::max(GetIdBound(), inst_.type_id + 1));
      PromoteIfNeeded(inst_.type_id);
      break;
    }

    case SPV_OPERAND_TYPE_ID:
    case SPV_OPERAND_TYPE_OPTIONAL_ID:
    case SPV_OPERAND_TYPE_SCOPE_ID:
    case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
      uint32_t id = 0;
      const spv_result_t result = DecodeRefId(&id);
      if (result != SPV_SUCCESS) return result;

      if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";

      if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
        // Record the optional form as the concrete id type.
        operand_.type = SPV_OPERAND_TYPE_ID;

        if (opcode == SpvOpExtInst && operand_.offset == 3) {
          // The current word is the extended instruction set id.
          // Set the extended instruction set type for the current
          // instruction.
          auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
          if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
            return Diag(SPV_ERROR_INVALID_ID)
                   << "OpExtInst set id " << id
                   << " does not reference an OpExtInstImport result Id";
          }
          inst_.ext_inst_type = ext_inst_type_iter->second;
        }
      }

      inst_words_.push_back(id);
      SetIdBound(std::max(GetIdBound(), id + 1));
      PromoteIfNeeded(id);
      break;
    }

    case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
      uint32_t word = 0;
      const spv_result_t result = DecodeNonIdWord(&word);
      if (result != SPV_SUCCESS) return result;

      inst_words_.push_back(word);

      assert(SpvOpExtInst == opcode);
      assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE);
      spv_ext_inst_desc ext_inst;
      // A non-zero lookup result means the number is unknown to the grammar.
      if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst))
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Invalid extended instruction number: " << word;
      // Expect the operands of the extended instruction next.
      spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
      break;
    }

    case SPV_OPERAND_TYPE_LITERAL_INTEGER:
    case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
      // These are regular single-word literal integer operands.
      // Post-parsing validation should check the range of the parsed value.
      operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
      // It turns out they are always unsigned integers!
      operand_.number_kind = SPV_NUMBER_UNSIGNED_INT;
      operand_.number_bit_width = 32;

      uint32_t word = 0;
      const spv_result_t result = DecodeNonIdWord(&word);
      if (result != SPV_SUCCESS) return result;

      inst_words_.push_back(word);
      break;
    }

    case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
    case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: {
      operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
      if (opcode == SpvOpSwitch) {
        // The literal operands have the same type as the value
        // referenced by the selector Id.
        const uint32_t selector_id = inst_words_.at(1);
        const auto type_id_iter = id_to_type_id_.find(selector_id);
        if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) {
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Invalid OpSwitch: selector id " << selector_id
                 << " has no type";
        }
        uint32_t type_id = type_id_iter->second;

        if (selector_id == type_id) {
          // Recall that by convention, a result ID that is a type definition
          // maps to itself.
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Invalid OpSwitch: selector id " << selector_id
                 << " is a type, not a value";
        }
        // Fills in operand_.number_kind and operand_.number_bit_width.
        if (auto error = SetNumericTypeInfoForType(&operand_, type_id))
          return error;
        if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT &&
            operand_.number_kind != SPV_NUMBER_SIGNED_INT) {
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Invalid OpSwitch: selector id " << selector_id
                 << " is not a scalar integer";
        }
      } else {
        assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
        // The literal number type is determined by the type Id for the
        // constant.
        assert(inst_.type_id);
        if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id))
          return error;
      }

      // Decodes one or two words depending on the bit width set above.
      if (auto error = DecodeLiteralNumber(operand_)) return error;

      break;
    }

    case SPV_OPERAND_TYPE_LITERAL_STRING:
    case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
      operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING;
      // Holds the decoded string including its terminating '\0'.
      std::vector<char> str;
      auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode);

      if (codec) {
        // Try to decode the whole string as a single Huffman symbol.
        std::string decoded_string;
        const bool huffman_result =
            codec->DecodeFromStream(GetReadBitCallback(), &decoded_string);
        assert(huffman_result);
        if (!huffman_result)
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Failed to read literal string";

        // "kMarkvNoneOfTheAbove" is the escape symbol: str is left empty so
        // that the character-by-character fallback below is used.
        if (decoded_string != "kMarkvNoneOfTheAbove") {
          std::copy(decoded_string.begin(), decoded_string.end(),
                    std::back_inserter(str));
          str.push_back('\0');
        }
      }

      // The loop is expected to terminate once we encounter '\0' or exhaust
      // the bit stream.
      if (str.empty()) {
        while (true) {
          char ch = 0;
          if (!reader_.ReadUnencoded(&ch))
            return Diag(SPV_ERROR_INVALID_BINARY)
                   << "Failed to read literal string";

          str.push_back(ch);

          if (ch == '\0') break;
        }
      }

      // Pad with '\0' to a whole number of 32-bit words.
      while (str.size() % 4 != 0) str.push_back('\0');

      // Copy the bytes of the string into the new instruction words.
      inst_words_.resize(inst_words_.size() + str.size() / 4);
      std::memcpy(&inst_words_[first_word_index], str.data(), str.size());

      if (SpvOpExtInstImport == opcode) {
        // Record the extended instruction type for the ID for this import.
        // There is only one string literal argument to OpExtInstImport,
        // so it's sufficient to guard this just on the opcode.
        const spv_ext_inst_type_t ext_inst_type =
            spvExtInstImportTypeGet(str.data());
        if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Invalid extended instruction import '" << str.data()
                 << "'";
        }
        // We must have parsed a valid result ID.  It's a condition
        // of the grammar, and we only accept non-zero result Ids.
        assert(inst_.result_id);
        const bool inserted =
            import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type)
                .second;
        // Silences the unused-variable warning when asserts are disabled.
        (void)inserted;
        assert(inserted);
      }
      break;
    }

    case SPV_OPERAND_TYPE_CAPABILITY:
    case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
    case SPV_OPERAND_TYPE_EXECUTION_MODEL:
    case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
    case SPV_OPERAND_TYPE_MEMORY_MODEL:
    case SPV_OPERAND_TYPE_EXECUTION_MODE:
    case SPV_OPERAND_TYPE_STORAGE_CLASS:
    case SPV_OPERAND_TYPE_DIMENSIONALITY:
    case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
    case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
    case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
    case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
    case SPV_OPERAND_TYPE_LINKAGE_TYPE:
    case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
    case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
    case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
    case SPV_OPERAND_TYPE_DECORATION:
    case SPV_OPERAND_TYPE_BUILT_IN:
    case SPV_OPERAND_TYPE_GROUP_OPERATION:
    case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
    case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
      // A single word that is a plain enum value.
      uint32_t word = 0;
      const spv_result_t result = DecodeNonIdWord(&word);
      if (result != SPV_SUCCESS) return result;

      inst_words_.push_back(word);

      // Map an optional operand type to its corresponding concrete type.
      if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
        operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;

      // A non-zero lookup result means the value is not a known enumerant.
      spv_operand_desc entry;
      if (grammar_.lookupOperand(type, word, &entry)) {
        return Diag(SPV_ERROR_INVALID_BINARY)
               << "Invalid " << spvOperandTypeStr(operand_.type)
               << " operand: " << word;
      }

      // Prepare to accept operands to this operand, if needed.
      spvPushOperandTypes(entry->operandTypes, expected_operands);
      break;
    }

    case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
    case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
    case SPV_OPERAND_TYPE_LOOP_CONTROL:
    case SPV_OPERAND_TYPE_IMAGE:
    case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
    case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
    case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
      // This operand is a mask.
      uint32_t word = 0;
      const spv_result_t result = DecodeNonIdWord(&word);
      if (result != SPV_SUCCESS) return result;

      inst_words_.push_back(word);

      // Map an optional operand type to its corresponding concrete type.
      if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
        operand_.type = SPV_OPERAND_TYPE_IMAGE;
      else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
        operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;

      // Check validity of set mask bits. Also prepare for operands for those
      // masks if they have any.  To get operand order correct, scan from
      // MSB to LSB since we can only prepend operands to a pattern.
      // The only case in the grammar where you have more than one mask bit
      // having an operand is for image operands.  See SPIR-V 3.14 Image
      // Operands.
      uint32_t remaining_word = word;
      // The loop ends as soon as every set bit has been cleared.
      for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
        if (remaining_word & mask) {
          spv_operand_desc entry;
          if (grammar_.lookupOperand(type, mask, &entry)) {
            return Diag(SPV_ERROR_INVALID_BINARY)
                   << "Invalid " << spvOperandTypeStr(operand_.type)
                   << " operand: " << word << " has invalid mask component "
                   << mask;
          }
          remaining_word ^= mask;
          spvPushOperandTypes(entry->operandTypes, expected_operands);
        }
      }
      if (word == 0) {
        // An all-zeroes mask *might* also be valid.
        spv_operand_desc entry;
        if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
          // Prepare for its operands, if any.
          spvPushOperandTypes(entry->operandTypes, expected_operands);
        }
      }
      break;
    }
    default:
      return Diag(SPV_ERROR_INVALID_BINARY)
             << "Internal error: Unhandled operand type: " << type;
  }

  // Every word appended since first_word_index belongs to this operand.
  operand_.num_words = uint16_t(inst_words_.size() - first_word_index);

  // Optional operand types must have been mapped to concrete ones by now.
  assert(spvOperandIsConcrete(operand_.type));

  parsed_operands_.push_back(operand_);

  return SPV_SUCCESS;
}
771 
DecodeInstruction()772 spv_result_t MarkvDecoder::DecodeInstruction() {
773   parsed_operands_.clear();
774   inst_words_.clear();
775 
776   // Opcode/num_words placeholder, the word will be filled in later.
777   inst_words_.push_back(0);
778 
779   bool num_operands_still_unknown = true;
780   {
781     uint32_t opcode = 0;
782     uint32_t num_operands = 0;
783 
784     const spv_result_t opcode_decoding_result =
785         DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands);
786     if (opcode_decoding_result < 0) return opcode_decoding_result;
787 
788     if (opcode_decoding_result == SPV_SUCCESS) {
789       inst_.num_operands = static_cast<uint16_t>(num_operands);
790       num_operands_still_unknown = false;
791     } else {
792       if (!reader_.ReadVariableWidthU32(&opcode,
793                                         model_->opcode_chunk_length())) {
794         return Diag(SPV_ERROR_INVALID_BINARY)
795                << "Failed to read opcode of instruction";
796       }
797     }
798 
799     inst_.opcode = static_cast<uint16_t>(opcode);
800   }
801 
802   const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
803 
804   spv_opcode_desc opcode_desc;
805   if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
806     return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
807   }
808 
809   spv_operand_pattern_t expected_operands;
810   expected_operands.reserve(opcode_desc->numTypes);
811   for (auto i = 0; i < opcode_desc->numTypes; i++) {
812     expected_operands.push_back(
813         opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
814   }
815 
816   if (num_operands_still_unknown) {
817     if (!OpcodeHasFixedNumberOfOperands(opcode)) {
818       if (!reader_.ReadVariableWidthU16(&inst_.num_operands,
819                                         model_->num_operands_chunk_length()))
820         return Diag(SPV_ERROR_INVALID_BINARY)
821                << "Failed to read num_operands of instruction";
822     } else {
823       inst_.num_operands = static_cast<uint16_t>(expected_operands.size());
824     }
825   }
826 
827   for (operand_index_ = 0;
828        operand_index_ < static_cast<size_t>(inst_.num_operands);
829        ++operand_index_) {
830     assert(!expected_operands.empty());
831     const spv_operand_type_t type =
832         spvTakeFirstMatchableOperand(&expected_operands);
833 
834     const size_t operand_offset = inst_words_.size();
835 
836     const spv_result_t decode_result =
837         DecodeOperand(operand_offset, type, &expected_operands);
838 
839     if (decode_result != SPV_SUCCESS) return decode_result;
840   }
841 
842   assert(inst_.num_operands == parsed_operands_.size());
843 
844   // Only valid while inst_words_ and parsed_operands_ remain unchanged (until
845   // next DecodeInstruction call).
846   inst_.words = inst_words_.data();
847   inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
848   inst_.num_words = static_cast<uint16_t>(inst_words_.size());
849   inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode));
850 
851   std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_));
852 
853   assert(inst_.num_words ==
854              std::accumulate(
855                  parsed_operands_.begin(), parsed_operands_.end(), 1,
856                  [](int num_words, const spv_parsed_operand_t& operand) {
857                    return num_words += operand.num_words;
858                  }) &&
859          "num_words in instruction doesn't correspond to the sum of num_words"
860          "in the operands");
861 
862   RecordNumberType();
863   ProcessCurInstruction();
864 
865   if (!ReadToByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte))
866     return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break";
867 
868   if (logger_) {
869     logger_->NewLine();
870     std::stringstream ss;
871     ss << spvOpcodeString(opcode) << " ";
872     for (size_t index = 1; index < inst_words_.size(); ++index)
873       ss << inst_words_[index] << " ";
874     logger_->AppendText(ss.str());
875     logger_->NewLine();
876     logger_->NewLine();
877     if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
878   }
879 
880   return SPV_SUCCESS;
881 }
882 
SetNumericTypeInfoForType(spv_parsed_operand_t * parsed_operand,uint32_t type_id)883 spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
884     spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
885   assert(type_id != 0);
886   auto type_info_iter = type_id_to_number_type_info_.find(type_id);
887   if (type_info_iter == type_id_to_number_type_info_.end()) {
888     return Diag(SPV_ERROR_INVALID_BINARY)
889            << "Type Id " << type_id << " is not a type";
890   }
891 
892   const NumberType& info = type_info_iter->second;
893   if (info.type == SPV_NUMBER_NONE) {
894     // This is a valid type, but for something other than a scalar number.
895     return Diag(SPV_ERROR_INVALID_BINARY)
896            << "Type Id " << type_id << " is not a scalar numeric type";
897   }
898 
899   parsed_operand->number_kind = info.type;
900   parsed_operand->number_bit_width = info.bit_width;
901   // Round up the word count.
902   parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
903   return SPV_SUCCESS;
904 }
905 
RecordNumberType()906 void MarkvDecoder::RecordNumberType() {
907   const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
908   if (spvOpcodeGeneratesType(opcode)) {
909     NumberType info = {SPV_NUMBER_NONE, 0};
910     if (SpvOpTypeInt == opcode) {
911       info.bit_width = inst_.words[inst_.operands[1].offset];
912       info.type = inst_.words[inst_.operands[2].offset]
913                       ? SPV_NUMBER_SIGNED_INT
914                       : SPV_NUMBER_UNSIGNED_INT;
915     } else if (SpvOpTypeFloat == opcode) {
916       info.bit_width = inst_.words[inst_.operands[1].offset];
917       info.type = SPV_NUMBER_FLOATING;
918     }
919     // The *result* Id of a type generating instruction is the type Id.
920     type_id_to_number_type_info_[inst_.result_id] = info;
921   }
922 }
923 
924 }  // namespace comp
925 }  // namespace spvtools
926