1 // Copyright (c) 2018 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "source/comp/bit_stream.h" 16 #include "source/comp/markv.h" 17 #include "source/comp/markv_codec.h" 18 #include "source/comp/markv_logger.h" 19 #include "source/util/make_unique.h" 20 21 #ifndef SOURCE_COMP_MARKV_DECODER_H_ 22 #define SOURCE_COMP_MARKV_DECODER_H_ 23 24 namespace spvtools { 25 namespace comp { 26 27 class MarkvLogger; 28 29 // Decodes MARK-V buffers written by MarkvEncoder. 30 class MarkvDecoder : public MarkvCodec { 31 public: 32 // |model| is owned by the caller, must be not null and valid during the 33 // lifetime of MarkvEncoder. MarkvDecoder(spv_const_context context,const std::vector<uint8_t> & markv,const MarkvCodecOptions & options,const MarkvModel * model)34 MarkvDecoder(spv_const_context context, const std::vector<uint8_t>& markv, 35 const MarkvCodecOptions& options, const MarkvModel* model) 36 : MarkvCodec(context, GetValidatorOptions(options), model), 37 options_(options), 38 reader_(markv) { 39 SetIdBound(1); 40 parsed_operands_.reserve(25); 41 inst_words_.reserve(25); 42 } 43 ~MarkvDecoder() = default; 44 45 // Creates an internal logger which writes comments on the decoding process. CreateLogger(MarkvLogConsumer log_consumer,MarkvDebugConsumer debug_consumer)46 void CreateLogger(MarkvLogConsumer log_consumer, 47 MarkvDebugConsumer debug_consumer) { 48 logger_ = MakeUnique<MarkvLogger>(log_consumer, debug_consumer); 49 } 50 51 // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|. 52 // Can be called only once. Fails if data of wrong format or ends prematurely, 53 // of if validation fails. 54 spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary); 55 56 // Creates and returns validator options. Returned value owned by the caller. GetValidatorOptions(const MarkvCodecOptions & options)57 static spv_validator_options GetValidatorOptions( 58 const MarkvCodecOptions& options) { 59 return options.validate_spirv_binary ? spvValidatorOptionsCreate() 60 : nullptr; 61 } 62 63 private: 64 // Describes the format of a typed literal number. 65 struct NumberType { 66 spv_number_kind_t type; 67 uint32_t bit_width; 68 }; 69 70 // Reads a single bit from reader_. The read bit is stored in |bit|. 71 // Returns false iff reader_ fails. ReadBit(bool * bit)72 bool ReadBit(bool* bit) { 73 uint64_t bits = 0; 74 const bool result = reader_.ReadBits(&bits, 1); 75 if (result) *bit = bits ? true : false; 76 return result; 77 }; 78 79 // Returns ReadBit bound to the class object. GetReadBitCallback()80 std::function<bool(bool*)> GetReadBitCallback() { 81 return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1); 82 } 83 84 // Reads a single non-id word from bit stream. operand_.type determines if 85 // the word needs to be decoded and how. 86 spv_result_t DecodeNonIdWord(uint32_t* word); 87 88 // Reads and decodes both opcode and num_operands as a single code. 89 // Returns SPV_UNSUPPORTED iff no suitable codec was found. 90 spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode, 91 uint32_t* num_operands); 92 93 // Reads mtf rank from bit stream. |mtf| is used to determine the codec 94 // scheme. |fallback_method| is used if no codec defined for |mtf|. 95 spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method, 96 uint32_t* rank); 97 98 // Reads id using coding based on mtf associated with the id descriptor. 99 // Returns SPV_UNSUPPORTED iff fallback method needs to be used. 100 spv_result_t DecodeIdWithDescriptor(uint32_t* id); 101 102 // Reads id using coding based on the given |mtf|, which is expected to 103 // contain the needed |id|. 104 spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id); 105 106 // Reads type id of the current instruction if can't be inferred. 107 spv_result_t DecodeTypeId(); 108 109 // Reads result id of the current instruction if can't be inferred. 110 spv_result_t DecodeResultId(); 111 112 // Reads id which is neither type nor result id. 113 spv_result_t DecodeRefId(uint32_t* id); 114 115 // Reads and discards bits until the beginning of the next byte if the 116 // number of bits until the next byte is less than |byte_break_if_less_than|. 117 bool ReadToByteBreak(size_t byte_break_if_less_than); 118 119 // Returns instruction words decoded up to this point. GetInstWords()120 const uint32_t* GetInstWords() const override { return inst_words_.data(); } 121 122 // Reads a literal number as it is described in |operand| from the bit stream, 123 // decodes and writes it to spirv_. 124 spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand); 125 126 // Reads instruction from bit stream, decodes and validates it. 127 // Decoded instruction is valid until the next call of DecodeInstruction(). 128 spv_result_t DecodeInstruction(); 129 130 // Read operand from the stream decodes and validates it. 131 spv_result_t DecodeOperand(size_t operand_offset, 132 const spv_operand_type_t type, 133 spv_operand_pattern_t* expected_operands); 134 135 // Records the numeric type for an operand according to the type information 136 // associated with the given non-zero type Id. This can fail if the type Id 137 // is not a type Id, or if the type Id does not reference a scalar numeric 138 // type. On success, return SPV_SUCCESS and populates the num_words, 139 // number_kind, and number_bit_width fields of parsed_operand. 140 spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand, 141 uint32_t type_id); 142 143 // Records the number type for the current instruction, if it generates a 144 // type. For types that aren't scalar numbers, record something with number 145 // kind SPV_NUMBER_NONE. 146 void RecordNumberType(); 147 148 MarkvCodecOptions options_; 149 150 // Temporary sink where decoded SPIR-V words are written. Once it contains the 151 // entire module, the container is moved and returned. 152 std::vector<uint32_t> spirv_; 153 154 // Bit stream containing encoded data. 155 BitReaderWord64 reader_; 156 157 // Temporary storage for operands of the currently parsed instruction. 158 // Valid until next DecodeInstruction call. 159 std::vector<spv_parsed_operand_t> parsed_operands_; 160 161 // Temporary storage for current instruction words. 162 // Valid until next DecodeInstruction call. 163 std::vector<uint32_t> inst_words_; 164 165 // Maps a type ID to its number type description. 166 std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_; 167 168 // Maps an ExtInstImport id to the extended instruction type. 169 std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_; 170 }; 171 172 } // namespace comp 173 } // namespace spvtools 174 175 #endif // SOURCE_COMP_MARKV_DECODER_H_ 176