1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/set_spec_constant_default_value_pass.h"
16 
17 #include <algorithm>
18 #include <cctype>
19 #include <cstring>
20 #include <tuple>
21 #include <vector>
22 
23 #include "source/opt/def_use_manager.h"
24 #include "source/opt/ir_context.h"
25 #include "source/opt/type_manager.h"
26 #include "source/opt/types.h"
27 #include "source/util/make_unique.h"
28 #include "source/util/parse_number.h"
29 #include "spirv-tools/libspirv.h"
30 
31 namespace spvtools {
32 namespace opt {
33 
34 namespace {
35 using utils::EncodeNumberStatus;
36 using utils::NumberType;
37 using utils::ParseAndEncodeNumber;
38 using utils::ParseNumber;
39 
40 // Given a numeric value in a null-terminated c string and the expected type of
41 // the value, parses the string and encodes it in a vector of words. If the
42 // value is a scalar integer or floating point value, encodes the value in
43 // SPIR-V encoding format. If the value is 'false' or 'true', returns a vector
44 // with single word with value 0 or 1 respectively. Returns the vector
45 // containing the encoded value on success. Otherwise returns an empty vector.
ParseDefaultValueStr(const char * text,const analysis::Type * type)46 std::vector<uint32_t> ParseDefaultValueStr(const char* text,
47                                            const analysis::Type* type) {
48   std::vector<uint32_t> result;
49   if (!strcmp(text, "true") && type->AsBool()) {
50     result.push_back(1u);
51   } else if (!strcmp(text, "false") && type->AsBool()) {
52     result.push_back(0u);
53   } else {
54     NumberType number_type = {32, SPV_NUMBER_UNSIGNED_INT};
55     if (const auto* IT = type->AsInteger()) {
56       number_type.bitwidth = IT->width();
57       number_type.kind =
58           IT->IsSigned() ? SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT;
59     } else if (const auto* FT = type->AsFloat()) {
60       number_type.bitwidth = FT->width();
61       number_type.kind = SPV_NUMBER_FLOATING;
62     } else {
63       // Does not handle types other then boolean, integer or float. Returns
64       // empty vector.
65       result.clear();
66       return result;
67     }
68     EncodeNumberStatus rc = ParseAndEncodeNumber(
69         text, number_type, [&result](uint32_t word) { result.push_back(word); },
70         nullptr);
71     // Clear the result vector on failure.
72     if (rc != EncodeNumberStatus::kSuccess) {
73       result.clear();
74     }
75   }
76   return result;
77 }
78 
79 // Given a bit pattern and a type, checks if the bit pattern is compatible
80 // with the type. If so, returns the bit pattern, otherwise returns an empty
81 // bit pattern. If the given bit pattern is empty, returns an empty bit
82 // pattern. If the given type represents a SPIR-V Boolean type, the bit pattern
83 // to be returned is determined with the following standard:
84 //   If any words in the input bit pattern are non zero, returns a bit pattern
85 //   with 0x1, which represents a 'true'.
86 //   If all words in the bit pattern are zero, returns a bit pattern with 0x0,
87 //   which represents a 'false'.
ParseDefaultValueBitPattern(const std::vector<uint32_t> & input_bit_pattern,const analysis::Type * type)88 std::vector<uint32_t> ParseDefaultValueBitPattern(
89     const std::vector<uint32_t>& input_bit_pattern,
90     const analysis::Type* type) {
91   std::vector<uint32_t> result;
92   if (type->AsBool()) {
93     if (std::any_of(input_bit_pattern.begin(), input_bit_pattern.end(),
94                     [](uint32_t i) { return i != 0; })) {
95       result.push_back(1u);
96     } else {
97       result.push_back(0u);
98     }
99     return result;
100   } else if (const auto* IT = type->AsInteger()) {
101     if (IT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) {
102       return std::vector<uint32_t>(input_bit_pattern);
103     }
104   } else if (const auto* FT = type->AsFloat()) {
105     if (FT->width() == input_bit_pattern.size() * sizeof(uint32_t) * 8) {
106       return std::vector<uint32_t>(input_bit_pattern);
107     }
108   }
109   result.clear();
110   return result;
111 }
112 
113 // Returns true if the given instruction's result id could have a SpecId
114 // decoration.
CanHaveSpecIdDecoration(const Instruction & inst)115 bool CanHaveSpecIdDecoration(const Instruction& inst) {
116   switch (inst.opcode()) {
117     case SpvOp::SpvOpSpecConstant:
118     case SpvOp::SpvOpSpecConstantFalse:
119     case SpvOp::SpvOpSpecConstantTrue:
120       return true;
121     default:
122       return false;
123   }
124 }
125 
126 // Given a decoration group defining instruction that is decorated with SpecId
127 // decoration, finds the spec constant defining instruction which is the real
128 // target of the SpecId decoration. Returns the spec constant defining
129 // instruction if such an instruction is found, otherwise returns a nullptr.
GetSpecIdTargetFromDecorationGroup(const Instruction & decoration_group_defining_inst,analysis::DefUseManager * def_use_mgr)130 Instruction* GetSpecIdTargetFromDecorationGroup(
131     const Instruction& decoration_group_defining_inst,
132     analysis::DefUseManager* def_use_mgr) {
133   // Find the OpGroupDecorate instruction which consumes the given decoration
134   // group. Note that the given decoration group has SpecId decoration, which
135   // is unique for different spec constants. So the decoration group cannot be
136   // consumed by different OpGroupDecorate instructions. Therefore we only need
137   // the first OpGroupDecoration instruction that uses the given decoration
138   // group.
139   Instruction* group_decorate_inst = nullptr;
140   if (def_use_mgr->WhileEachUser(&decoration_group_defining_inst,
141                                  [&group_decorate_inst](Instruction* user) {
142                                    if (user->opcode() ==
143                                        SpvOp::SpvOpGroupDecorate) {
144                                      group_decorate_inst = user;
145                                      return false;
146                                    }
147                                    return true;
148                                  }))
149     return nullptr;
150 
151   // Scan through the target ids of the OpGroupDecorate instruction. There
152   // should be only one spec constant target consumes the SpecId decoration.
153   // If multiple target ids are presented in the OpGroupDecorate instruction,
154   // they must be the same one that defined by an eligible spec constant
155   // instruction. If the OpGroupDecorate instruction has different target ids
156   // or a target id is not defined by an eligible spec cosntant instruction,
157   // returns a nullptr.
158   Instruction* target_inst = nullptr;
159   for (uint32_t i = 1; i < group_decorate_inst->NumInOperands(); i++) {
160     // All the operands of a OpGroupDecorate instruction should be of type
161     // SPV_OPERAND_TYPE_ID.
162     uint32_t candidate_id = group_decorate_inst->GetSingleWordInOperand(i);
163     Instruction* candidate_inst = def_use_mgr->GetDef(candidate_id);
164 
165     if (!candidate_inst) {
166       continue;
167     }
168 
169     if (!target_inst) {
170       // If the spec constant target has not been found yet, check if the
171       // candidate instruction is the target.
172       if (CanHaveSpecIdDecoration(*candidate_inst)) {
173         target_inst = candidate_inst;
174       } else {
175         // Spec id decoration should not be applied on other instructions.
176         // TODO(qining): Emit an error message in the invalid case once the
177         // error handling is done.
178         return nullptr;
179       }
180     } else {
181       // If the spec constant target has been found, check if the candidate
182       // instruction is the same one as the target. The module is invalid if
183       // the candidate instruction is different with the found target.
184       // TODO(qining): Emit an error messaage in the invalid case once the
185       // error handling is done.
186       if (candidate_inst != target_inst) return nullptr;
187     }
188   }
189   return target_inst;
190 }
191 }  // namespace
192 
Process()193 Pass::Status SetSpecConstantDefaultValuePass::Process() {
194   // The operand index of decoration target in an OpDecorate instruction.
195   const uint32_t kTargetIdOperandIndex = 0;
196   // The operand index of the decoration literal in an OpDecorate instruction.
197   const uint32_t kDecorationOperandIndex = 1;
198   // The operand index of Spec id literal value in an OpDecorate SpecId
199   // instruction.
200   const uint32_t kSpecIdLiteralOperandIndex = 2;
201   // The number of operands in an OpDecorate SpecId instruction.
202   const uint32_t kOpDecorateSpecIdNumOperands = 3;
203   // The in-operand index of the default value in a OpSpecConstant instruction.
204   const uint32_t kOpSpecConstantLiteralInOperandIndex = 0;
205 
206   bool modified = false;
207   // Scan through all the annotation instructions to find 'OpDecorate SpecId'
208   // instructions. Then extract the decoration target of those instructions.
209   // The decoration targets should be spec constant defining instructions with
210   // opcode: OpSpecConstant{|True|False}. The spec id of those spec constants
211   // will be used to look up their new default values in the mapping from
212   // spec id to new default value strings. Once a new default value string
213   // is found for a spec id, the string will be parsed according to the target
214   // spec constant type. The parsed value will be used to replace the original
215   // default value of the target spec constant.
216   for (Instruction& inst : context()->annotations()) {
217     // Only process 'OpDecorate SpecId' instructions
218     if (inst.opcode() != SpvOp::SpvOpDecorate) continue;
219     if (inst.NumOperands() != kOpDecorateSpecIdNumOperands) continue;
220     if (inst.GetSingleWordInOperand(kDecorationOperandIndex) !=
221         uint32_t(SpvDecoration::SpvDecorationSpecId)) {
222       continue;
223     }
224 
225     // 'inst' is an OpDecorate SpecId instruction.
226     uint32_t spec_id = inst.GetSingleWordOperand(kSpecIdLiteralOperandIndex);
227     uint32_t target_id = inst.GetSingleWordOperand(kTargetIdOperandIndex);
228 
229     // Find the spec constant defining instruction. Note that the
230     // target_id might be a decoration group id.
231     Instruction* spec_inst = nullptr;
232     if (Instruction* target_inst = get_def_use_mgr()->GetDef(target_id)) {
233       if (target_inst->opcode() == SpvOp::SpvOpDecorationGroup) {
234         spec_inst =
235             GetSpecIdTargetFromDecorationGroup(*target_inst, get_def_use_mgr());
236       } else {
237         spec_inst = target_inst;
238       }
239     } else {
240       continue;
241     }
242     if (!spec_inst) continue;
243 
244     // Get the default value bit pattern for this spec id.
245     std::vector<uint32_t> bit_pattern;
246 
247     if (spec_id_to_value_str_.size() != 0) {
248       // Search for the new string-form default value for this spec id.
249       auto iter = spec_id_to_value_str_.find(spec_id);
250       if (iter == spec_id_to_value_str_.end()) {
251         continue;
252       }
253 
254       // Gets the string of the default value and parses it to bit pattern
255       // with the type of the spec constant.
256       const std::string& default_value_str = iter->second;
257       bit_pattern = ParseDefaultValueStr(
258           default_value_str.c_str(),
259           context()->get_type_mgr()->GetType(spec_inst->type_id()));
260 
261     } else {
262       // Search for the new bit-pattern-form default value for this spec id.
263       auto iter = spec_id_to_value_bit_pattern_.find(spec_id);
264       if (iter == spec_id_to_value_bit_pattern_.end()) {
265         continue;
266       }
267 
268       // Gets the bit-pattern of the default value from the map directly.
269       bit_pattern = ParseDefaultValueBitPattern(
270           iter->second,
271           context()->get_type_mgr()->GetType(spec_inst->type_id()));
272     }
273 
274     if (bit_pattern.empty()) continue;
275 
276     // Update the operand bit patterns of the spec constant defining
277     // instruction.
278     switch (spec_inst->opcode()) {
279       case SpvOp::SpvOpSpecConstant:
280         // If the new value is the same with the original value, no
281         // need to do anything. Otherwise update the operand words.
282         if (spec_inst->GetInOperand(kOpSpecConstantLiteralInOperandIndex)
283                 .words != bit_pattern) {
284           spec_inst->SetInOperand(kOpSpecConstantLiteralInOperandIndex,
285                                   std::move(bit_pattern));
286           modified = true;
287         }
288         break;
289       case SpvOp::SpvOpSpecConstantTrue:
290         // If the new value is also 'true', no need to change anything.
291         // Otherwise, set the opcode to OpSpecConstantFalse;
292         if (!static_cast<bool>(bit_pattern.front())) {
293           spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantFalse);
294           modified = true;
295         }
296         break;
297       case SpvOp::SpvOpSpecConstantFalse:
298         // If the new value is also 'false', no need to change anything.
299         // Otherwise, set the opcode to OpSpecConstantTrue;
300         if (static_cast<bool>(bit_pattern.front())) {
301           spec_inst->SetOpcode(SpvOp::SpvOpSpecConstantTrue);
302           modified = true;
303         }
304         break;
305       default:
306         break;
307     }
308     // No need to update the DefUse manager, as this pass does not change any
309     // ids.
310   }
311   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
312 }
313 
// Returns true if the given char is ':', '\0' or considered as blank space
// by std::isspace (in the default "C" locale: '\n', '\r', '\v', '\t', '\f'
// and ' ').
// The char is converted to unsigned char before it is handed to
// std::isspace, whose behavior is undefined for negative arguments other than
// EOF. A plain char holding a non-ASCII byte may be negative.
bool IsSeparator(char ch) {
  if (ch == ':' || ch == '\0') return true;
  return std::isspace(static_cast<unsigned char>(ch)) != 0;
}
319 
320 std::unique_ptr<SetSpecConstantDefaultValuePass::SpecIdToValueStrMap>
ParseDefaultValuesString(const char * str)321 SetSpecConstantDefaultValuePass::ParseDefaultValuesString(const char* str) {
322   if (!str) return nullptr;
323 
324   auto spec_id_to_value = MakeUnique<SpecIdToValueStrMap>();
325 
326   // The parsing loop, break when points to the end.
327   while (*str) {
328     // Find the spec id.
329     while (std::isspace(*str)) str++;  // skip leading spaces.
330     const char* entry_begin = str;
331     while (!IsSeparator(*str)) str++;
332     const char* entry_end = str;
333     std::string spec_id_str(entry_begin, entry_end - entry_begin);
334     uint32_t spec_id = 0;
335     if (!ParseNumber(spec_id_str.c_str(), &spec_id)) {
336       // The spec id is not a valid uint32 number.
337       return nullptr;
338     }
339     auto iter = spec_id_to_value->find(spec_id);
340     if (iter != spec_id_to_value->end()) {
341       // Same spec id has been defined before
342       return nullptr;
343     }
344     // Find the ':', spaces between the spec id and the ':' are not allowed.
345     if (*str++ != ':') {
346       // ':' not found
347       return nullptr;
348     }
349     // Find the value string
350     const char* val_begin = str;
351     while (!IsSeparator(*str)) str++;
352     const char* val_end = str;
353     if (val_end == val_begin) {
354       // Value string is empty.
355       return nullptr;
356     }
357     // Update the mapping with spec id and value string.
358     (*spec_id_to_value)[spec_id] = std::string(val_begin, val_end - val_begin);
359 
360     // Skip trailing spaces.
361     while (std::isspace(*str)) str++;
362   }
363 
364   return spec_id_to_value;
365 }
366 
367 }  // namespace opt
368 }  // namespace spvtools
369