1 /* Copyright (c) 2015-2019 The Khronos Group Inc.
2  * Copyright (c) 2015-2019 Valve Corporation
3  * Copyright (c) 2015-2019 LunarG, Inc.
4  * Copyright (C) 2015-2019 Google Inc.
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  *
18  * Author: Chris Forbes <chrisf@ijw.co.nz>
19  */
20 #ifndef VULKAN_SHADER_VALIDATION_H
21 #define VULKAN_SHADER_VALIDATION_H
22 
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <SPIRV/spirv.hpp>
#include <generated/spirv_tools_commit_id.h>
#include "spirv-tools/optimizer.hpp"
28 
// A forward iterator over spirv instructions. Provides easy access to len, opcode, and content words
// without the caller needing to care too much about the physical SPIRV module layout.
// Each instruction's first word packs the word count in its high 16 bits and the opcode in its low 16 bits.
struct spirv_inst_iter {
    // Start of the module's word stream; used to report offsets.
    std::vector<uint32_t>::const_iterator zero;
    // First word of the current instruction.
    std::vector<uint32_t>::const_iterator it;

    // Word count of the current instruction (including the opcode word itself).
    uint32_t len() const {
        auto result = *it >> 16;
        assert(result > 0);
        return result;
    }

    // Opcode of the current instruction. Const: reading the opcode does not move the iterator.
    uint32_t opcode() const { return *it & 0x0ffffu; }

    // The n-th word of the current instruction; word(0) is the length/opcode word.
    uint32_t const &word(unsigned n) const {
        assert(n < len());
        return it[n];
    }

    // Offset of the current instruction, in words, from the start of the module.
    uint32_t offset() const { return static_cast<uint32_t>(it - zero); }

    spirv_inst_iter() {}

    spirv_inst_iter(std::vector<uint32_t>::const_iterator zero, std::vector<uint32_t>::const_iterator it) : zero(zero), it(it) {}

    bool operator==(spirv_inst_iter const &other) const { return it == other.it; }

    bool operator!=(spirv_inst_iter const &other) const { return it != other.it; }

    spirv_inst_iter operator++(int) {  // x++
        spirv_inst_iter ii = *this;
        it += len();
        return ii;
    }

    // ++x: returns a reference (not a copy) so chained increments act on this iterator.
    spirv_inst_iter &operator++() {
        it += len();
        return *this;
    }

    // The iterator and the value are the same thing.
    spirv_inst_iter &operator*() { return *this; }
    spirv_inst_iter const &operator*() const { return *this; }
};
73 
// Decorations collected for a single SPIR-V <id>. `flags` records which decorations have been seen;
// a value field is only meaningful when its corresponding bit is set in `flags`.
struct decoration_set {
    enum {
        location_bit = 1 << 0,
        patch_bit = 1 << 1,
        relaxed_precision_bit = 1 << 2,
        block_bit = 1 << 3,
        buffer_block_bit = 1 << 4,
        component_bit = 1 << 5,
        input_attachment_index_bit = 1 << 6,
        descriptor_set_bit = 1 << 7,
        binding_bit = 1 << 8,
        nonwritable_bit = 1 << 9,
        builtin_bit = 1 << 10,
    };
    uint32_t flags = 0;
    uint32_t location = static_cast<uint32_t>(-1);
    uint32_t component = 0;
    uint32_t input_attachment_index = 0;
    uint32_t descriptor_set = 0;
    uint32_t binding = 0;
    uint32_t builtin = static_cast<uint32_t>(-1);

    // Folds `other` into this set. Each value whose bit is set in other.flags overwrites ours, and
    // other's flags are then OR-ed into ours. Values whose bit is clear in `other` are left untouched.
    void merge(decoration_set const &other) {
        const uint32_t incoming = other.flags;
        auto adopt = [incoming](uint32_t bit, uint32_t &dst, uint32_t src) {
            if ((incoming & bit) != 0) dst = src;
        };
        adopt(location_bit, location, other.location);
        adopt(component_bit, component, other.component);
        adopt(input_attachment_index_bit, input_attachment_index, other.input_attachment_index);
        adopt(descriptor_set_bit, descriptor_set, other.descriptor_set);
        adopt(binding_bit, binding, other.binding);
        adopt(builtin_bit, builtin, other.builtin);
        flags = flags | incoming;
    }

    // Declared here and defined elsewhere; presumably records one decoration and its operand value.
    void add(uint32_t decoration, uint32_t value);
};
108 
109 struct SHADER_MODULE_STATE {
110     // The spirv image itself
111     std::vector<uint32_t> words;
112     // A mapping of <id> to the first word of its def. this is useful because walking type
113     // trees, constant expressions, etc requires jumping all over the instruction stream.
114     std::unordered_map<unsigned, unsigned> def_index;
115     std::unordered_map<unsigned, decoration_set> decorations;
116     struct EntryPoint {
117         uint32_t offset;
118         VkShaderStageFlags stage;
119     };
120     std::unordered_multimap<std::string, EntryPoint> entry_points;
121     bool has_valid_spirv;
122     VkShaderModule vk_shader_module;
123     uint32_t gpu_validation_shader_id;
124 
PreprocessShaderBinarySHADER_MODULE_STATE125     std::vector<uint32_t> PreprocessShaderBinary(uint32_t *src_binary, size_t binary_size, spv_target_env env) {
126         std::vector<uint32_t> src(src_binary, src_binary + binary_size / sizeof(uint32_t));
127 
128         // Check if there are any group decoration instructions, and flatten them if found.
129         bool has_group_decoration = false;
130         bool done = false;
131 
132         // Walk through the first part of the SPIR-V module, looking for group decoration instructions.
133         // Skip the header (5 words).
134         auto itr = spirv_inst_iter(src.begin(), src.begin() + 5);
135         auto itrend = spirv_inst_iter(src.begin(), src.end());
136         while (itr != itrend && !done) {
137             spv::Op opcode = (spv::Op)itr.opcode();
138             switch (opcode) {
139                 case spv::OpDecorationGroup:
140                 case spv::OpGroupDecorate:
141                 case spv::OpGroupMemberDecorate:
142                     has_group_decoration = true;
143                     done = true;
144                     break;
145                 case spv::OpFunction:
146                     // An OpFunction indicates there are no more decorations
147                     done = true;
148                     break;
149                 default:
150                     break;
151             }
152             itr++;
153         }
154 
155         if (has_group_decoration) {
156             spvtools::Optimizer optimizer(env);
157             optimizer.RegisterPass(spvtools::CreateFlattenDecorationPass());
158             std::vector<uint32_t> optimized_binary;
159             // Run optimizer to flatten decorations only, set skip_validation so as to not re-run validator
160             auto result =
161                 optimizer.Run(src_binary, binary_size / sizeof(uint32_t), &optimized_binary, spvtools::ValidatorOptions(), true);
162             if (result) {
163                 return optimized_binary;
164             }
165         }
166         // Return the original module.
167         return src;
168     }
169 
SHADER_MODULE_STATESHADER_MODULE_STATE170     SHADER_MODULE_STATE(VkShaderModuleCreateInfo const *pCreateInfo, VkShaderModule shaderModule, spv_target_env env,
171                         uint32_t unique_shader_id)
172         : words(PreprocessShaderBinary((uint32_t *)pCreateInfo->pCode, pCreateInfo->codeSize, env)),
173           def_index(),
174           has_valid_spirv(true),
175           vk_shader_module(shaderModule),
176           gpu_validation_shader_id(unique_shader_id) {
177         BuildDefIndex();
178     }
179 
SHADER_MODULE_STATESHADER_MODULE_STATE180     SHADER_MODULE_STATE() : has_valid_spirv(false), vk_shader_module(VK_NULL_HANDLE) {}
181 
get_decorationsSHADER_MODULE_STATE182     decoration_set get_decorations(unsigned id) const {
183         // return the actual decorations for this id, or a default set.
184         auto it = decorations.find(id);
185         if (it != decorations.end()) return it->second;
186         return decoration_set();
187     }
188 
189     // Expose begin() / end() to enable range-based for
beginSHADER_MODULE_STATE190     spirv_inst_iter begin() const { return spirv_inst_iter(words.begin(), words.begin() + 5); }  // First insn
endSHADER_MODULE_STATE191     spirv_inst_iter end() const { return spirv_inst_iter(words.begin(), words.end()); }          // Just past last insn
192     // Given an offset into the module, produce an iterator there.
atSHADER_MODULE_STATE193     spirv_inst_iter at(unsigned offset) const { return spirv_inst_iter(words.begin(), words.begin() + offset); }
194 
195     // Gets an iterator to the definition of an id
get_defSHADER_MODULE_STATE196     spirv_inst_iter get_def(unsigned id) const {
197         auto it = def_index.find(id);
198         if (it == def_index.end()) {
199             return end();
200         }
201         return at(it->second);
202     }
203 
204     void BuildDefIndex();
205 };
206 
207 class ValidationCache {
208     // hashes of shaders that have passed validation before, and can be skipped.
209     // we don't store negative results, as we would have to also store what was
210     // wrong with them; also, we expect they will get fixed, so we're less
211     // likely to see them again.
212     std::unordered_set<uint32_t> good_shader_hashes;
ValidationCache()213     ValidationCache() {}
214 
215    public:
Create(VkValidationCacheCreateInfoEXT const * pCreateInfo)216     static VkValidationCacheEXT Create(VkValidationCacheCreateInfoEXT const *pCreateInfo) {
217         auto cache = new ValidationCache();
218         cache->Load(pCreateInfo);
219         return VkValidationCacheEXT(cache);
220     }
221 
Load(VkValidationCacheCreateInfoEXT const * pCreateInfo)222     void Load(VkValidationCacheCreateInfoEXT const *pCreateInfo) {
223         const auto headerSize = 2 * sizeof(uint32_t) + VK_UUID_SIZE;
224         auto size = headerSize;
225         if (!pCreateInfo->pInitialData || pCreateInfo->initialDataSize < size) return;
226 
227         uint32_t const *data = (uint32_t const *)pCreateInfo->pInitialData;
228         if (data[0] != size) return;
229         if (data[1] != VK_VALIDATION_CACHE_HEADER_VERSION_ONE_EXT) return;
230         uint8_t expected_uuid[VK_UUID_SIZE];
231         Sha1ToVkUuid(SPIRV_TOOLS_COMMIT_ID, expected_uuid);
232         if (memcmp(&data[2], expected_uuid, VK_UUID_SIZE) != 0) return;  // different version
233 
234         data = (uint32_t const *)(reinterpret_cast<uint8_t const *>(data) + headerSize);
235 
236         for (; size < pCreateInfo->initialDataSize; data++, size += sizeof(uint32_t)) {
237             good_shader_hashes.insert(*data);
238         }
239     }
240 
Write(size_t * pDataSize,void * pData)241     void Write(size_t *pDataSize, void *pData) {
242         const auto headerSize = 2 * sizeof(uint32_t) + VK_UUID_SIZE;  // 4 bytes for header size + 4 bytes for version number + UUID
243         if (!pData) {
244             *pDataSize = headerSize + good_shader_hashes.size() * sizeof(uint32_t);
245             return;
246         }
247 
248         if (*pDataSize < headerSize) {
249             *pDataSize = 0;
250             return;  // Too small for even the header!
251         }
252 
253         uint32_t *out = (uint32_t *)pData;
254         size_t actualSize = headerSize;
255 
256         // Write the header
257         *out++ = headerSize;
258         *out++ = VK_VALIDATION_CACHE_HEADER_VERSION_ONE_EXT;
259         Sha1ToVkUuid(SPIRV_TOOLS_COMMIT_ID, reinterpret_cast<uint8_t *>(out));
260         out = (uint32_t *)(reinterpret_cast<uint8_t *>(out) + VK_UUID_SIZE);
261 
262         for (auto it = good_shader_hashes.begin(); it != good_shader_hashes.end() && actualSize < *pDataSize;
263              it++, out++, actualSize += sizeof(uint32_t)) {
264             *out = *it;
265         }
266 
267         *pDataSize = actualSize;
268     }
269 
Merge(ValidationCache const * other)270     void Merge(ValidationCache const *other) {
271         good_shader_hashes.reserve(good_shader_hashes.size() + other->good_shader_hashes.size());
272         for (auto h : other->good_shader_hashes) good_shader_hashes.insert(h);
273     }
274 
275     static uint32_t MakeShaderHash(VkShaderModuleCreateInfo const *smci);
276 
Contains(uint32_t hash)277     bool Contains(uint32_t hash) { return good_shader_hashes.count(hash) != 0; }
278 
Insert(uint32_t hash)279     void Insert(uint32_t hash) { good_shader_hashes.insert(hash); }
280 
281    private:
Sha1ToVkUuid(const char * sha1_str,uint8_t uuid[VK_UUID_SIZE])282     void Sha1ToVkUuid(const char *sha1_str, uint8_t uuid[VK_UUID_SIZE]) {
283         // Convert sha1_str from a hex string to binary. We only need VK_UUID_BYTES of
284         // output, so pad with zeroes if the input string is shorter than that, and truncate
285         // if it's longer.
286         char padded_sha1_str[2 * VK_UUID_SIZE + 1] = {};
287         strncpy(padded_sha1_str, sha1_str, 2 * VK_UUID_SIZE + 1);
288         char byte_str[3] = {};
289         for (uint32_t i = 0; i < VK_UUID_SIZE; ++i) {
290             byte_str[0] = padded_sha1_str[2 * i + 0];
291             byte_str[1] = padded_sha1_str[2 * i + 1];
292             uuid[i] = static_cast<uint8_t>(strtol(byte_str, NULL, 16));
293         }
294     }
295 };
296 
297 #endif  // VULKAN_SHADER_VALIDATION_H
298