1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#ifndef SOURCE_COMP_MARKV_DECODER_H_
#define SOURCE_COMP_MARKV_DECODER_H_

#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

#include "source/comp/bit_stream.h"
#include "source/comp/markv.h"
#include "source/comp/markv_codec.h"
#include "source/comp/markv_logger.h"
#include "source/util/make_unique.h"
23 
24 namespace spvtools {
25 namespace comp {
26 
27 class MarkvLogger;
28 
29 // Decodes MARK-V buffers written by MarkvEncoder.
30 class MarkvDecoder : public MarkvCodec {
31  public:
32   // |model| is owned by the caller, must be not null and valid during the
33   // lifetime of MarkvEncoder.
MarkvDecoder(spv_const_context context,const std::vector<uint8_t> & markv,const MarkvCodecOptions & options,const MarkvModel * model)34   MarkvDecoder(spv_const_context context, const std::vector<uint8_t>& markv,
35                const MarkvCodecOptions& options, const MarkvModel* model)
36       : MarkvCodec(context, GetValidatorOptions(options), model),
37         options_(options),
38         reader_(markv) {
39     SetIdBound(1);
40     parsed_operands_.reserve(25);
41     inst_words_.reserve(25);
42   }
43   ~MarkvDecoder() = default;
44 
45   // Creates an internal logger which writes comments on the decoding process.
CreateLogger(MarkvLogConsumer log_consumer,MarkvDebugConsumer debug_consumer)46   void CreateLogger(MarkvLogConsumer log_consumer,
47                     MarkvDebugConsumer debug_consumer) {
48     logger_ = MakeUnique<MarkvLogger>(log_consumer, debug_consumer);
49   }
50 
51   // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
52   // Can be called only once. Fails if data of wrong format or ends prematurely,
53   // of if validation fails.
54   spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
55 
56   // Creates and returns validator options. Returned value owned by the caller.
GetValidatorOptions(const MarkvCodecOptions & options)57   static spv_validator_options GetValidatorOptions(
58       const MarkvCodecOptions& options) {
59     return options.validate_spirv_binary ? spvValidatorOptionsCreate()
60                                          : nullptr;
61   }
62 
63  private:
64   // Describes the format of a typed literal number.
65   struct NumberType {
66     spv_number_kind_t type;
67     uint32_t bit_width;
68   };
69 
70   // Reads a single bit from reader_. The read bit is stored in |bit|.
71   // Returns false iff reader_ fails.
ReadBit(bool * bit)72   bool ReadBit(bool* bit) {
73     uint64_t bits = 0;
74     const bool result = reader_.ReadBits(&bits, 1);
75     if (result) *bit = bits ? true : false;
76     return result;
77   };
78 
79   // Returns ReadBit bound to the class object.
GetReadBitCallback()80   std::function<bool(bool*)> GetReadBitCallback() {
81     return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1);
82   }
83 
84   // Reads a single non-id word from bit stream. operand_.type determines if
85   // the word needs to be decoded and how.
86   spv_result_t DecodeNonIdWord(uint32_t* word);
87 
88   // Reads and decodes both opcode and num_operands as a single code.
89   // Returns SPV_UNSUPPORTED iff no suitable codec was found.
90   spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode,
91                                                uint32_t* num_operands);
92 
93   // Reads mtf rank from bit stream. |mtf| is used to determine the codec
94   // scheme. |fallback_method| is used if no codec defined for |mtf|.
95   spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method,
96                                     uint32_t* rank);
97 
98   // Reads id using coding based on mtf associated with the id descriptor.
99   // Returns SPV_UNSUPPORTED iff fallback method needs to be used.
100   spv_result_t DecodeIdWithDescriptor(uint32_t* id);
101 
102   // Reads id using coding based on the given |mtf|, which is expected to
103   // contain the needed |id|.
104   spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id);
105 
106   // Reads type id of the current instruction if can't be inferred.
107   spv_result_t DecodeTypeId();
108 
109   // Reads result id of the current instruction if can't be inferred.
110   spv_result_t DecodeResultId();
111 
112   // Reads id which is neither type nor result id.
113   spv_result_t DecodeRefId(uint32_t* id);
114 
115   // Reads and discards bits until the beginning of the next byte if the
116   // number of bits until the next byte is less than |byte_break_if_less_than|.
117   bool ReadToByteBreak(size_t byte_break_if_less_than);
118 
119   // Returns instruction words decoded up to this point.
GetInstWords()120   const uint32_t* GetInstWords() const override { return inst_words_.data(); }
121 
122   // Reads a literal number as it is described in |operand| from the bit stream,
123   // decodes and writes it to spirv_.
124   spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
125 
126   // Reads instruction from bit stream, decodes and validates it.
127   // Decoded instruction is valid until the next call of DecodeInstruction().
128   spv_result_t DecodeInstruction();
129 
130   // Read operand from the stream decodes and validates it.
131   spv_result_t DecodeOperand(size_t operand_offset,
132                              const spv_operand_type_t type,
133                              spv_operand_pattern_t* expected_operands);
134 
135   // Records the numeric type for an operand according to the type information
136   // associated with the given non-zero type Id.  This can fail if the type Id
137   // is not a type Id, or if the type Id does not reference a scalar numeric
138   // type.  On success, return SPV_SUCCESS and populates the num_words,
139   // number_kind, and number_bit_width fields of parsed_operand.
140   spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
141                                          uint32_t type_id);
142 
143   // Records the number type for the current instruction, if it generates a
144   // type. For types that aren't scalar numbers, record something with number
145   // kind SPV_NUMBER_NONE.
146   void RecordNumberType();
147 
148   MarkvCodecOptions options_;
149 
150   // Temporary sink where decoded SPIR-V words are written. Once it contains the
151   // entire module, the container is moved and returned.
152   std::vector<uint32_t> spirv_;
153 
154   // Bit stream containing encoded data.
155   BitReaderWord64 reader_;
156 
157   // Temporary storage for operands of the currently parsed instruction.
158   // Valid until next DecodeInstruction call.
159   std::vector<spv_parsed_operand_t> parsed_operands_;
160 
161   // Temporary storage for current instruction words.
162   // Valid until next DecodeInstruction call.
163   std::vector<uint32_t> inst_words_;
164 
165   // Maps a type ID to its number type description.
166   std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
167 
168   // Maps an ExtInstImport id to the extended instruction type.
169   std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
170 };
171 
172 }  // namespace comp
173 }  // namespace spvtools
174 
175 #endif  // SOURCE_COMP_MARKV_DECODER_H_
176