1 // Copyright (c) 2017 Pierre Moreau
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "spirv-tools/linker.hpp"
16 
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
27 
28 #include "source/assembly_grammar.h"
29 #include "source/diagnostic.h"
30 #include "source/opt/build_module.h"
31 #include "source/opt/compact_ids_pass.h"
32 #include "source/opt/decoration_manager.h"
33 #include "source/opt/ir_loader.h"
34 #include "source/opt/pass_manager.h"
35 #include "source/opt/remove_duplicates_pass.h"
36 #include "source/spirv_target_env.h"
37 #include "source/util/make_unique.h"
38 #include "spirv-tools/libspirv.hpp"
39 
40 namespace spvtools {
41 namespace {
42 
43 using opt::IRContext;
44 using opt::Instruction;
45 using opt::Module;
46 using opt::Operand;
47 using opt::PassManager;
48 using opt::RemoveDuplicatesPass;
49 using opt::analysis::DecorationManager;
50 using opt::analysis::DefUseManager;
51 
52 // Stores various information about an imported or exported symbol.
53 struct LinkageSymbolInfo {
54   SpvId id;          // ID of the symbol
55   SpvId type_id;     // ID of the type of the symbol
56   std::string name;  // unique name defining the symbol and used for matching
57                      // imports and exports together
58   std::vector<SpvId> parameter_ids;  // ID of the parameters of the symbol, if
59                                      // it is a function
60 };
61 struct LinkageEntry {
62   LinkageSymbolInfo imported_symbol;
63   LinkageSymbolInfo exported_symbol;
64 
LinkageEntryspvtools::__anon724612b70111::LinkageEntry65   LinkageEntry(const LinkageSymbolInfo& import_info,
66                const LinkageSymbolInfo& export_info)
67       : imported_symbol(import_info), exported_symbol(export_info) {}
68 };
69 using LinkageTable = std::vector<LinkageEntry>;
70 
71 // Shifts the IDs used in each binary of |modules| so that they occupy a
72 // disjoint range from the other binaries, and compute the new ID bound which
73 // is returned in |max_id_bound|.
74 //
75 // Both |modules| and |max_id_bound| should not be null, and |modules| should
76 // not be empty either. Furthermore |modules| should not contain any null
77 // pointers.
78 spv_result_t ShiftIdsInModules(const MessageConsumer& consumer,
79                                std::vector<opt::Module*>* modules,
80                                uint32_t* max_id_bound);
81 
82 // Generates the header for the linked module and returns it in |header|.
83 //
84 // |header| should not be null, |modules| should not be empty and pointers
85 // should be non-null. |max_id_bound| should be strictly greater than 0.
86 //
87 // TODO(pierremoreau): What to do when binaries use different versions of
88 //                     SPIR-V? For now, use the max of all versions found in
89 //                     the input modules.
90 spv_result_t GenerateHeader(const MessageConsumer& consumer,
91                             const std::vector<opt::Module*>& modules,
92                             uint32_t max_id_bound, opt::ModuleHeader* header);
93 
94 // Merge all the modules from |in_modules| into a single module owned by
95 // |linked_context|.
96 //
97 // |linked_context| should not be null.
98 spv_result_t MergeModules(const MessageConsumer& consumer,
99                           const std::vector<Module*>& in_modules,
100                           const AssemblyGrammar& grammar,
101                           IRContext* linked_context);
102 
103 // Compute all pairs of import and export and return it in |linkings_to_do|.
104 //
105 // |linkings_to_do should not be null. Built-in symbols will be ignored.
106 //
107 // TODO(pierremoreau): Linkage attributes applied by a group decoration are
108 //                     currently not handled. (You could have a group being
109 //                     applied to a single ID.)
110 // TODO(pierremoreau): What should be the proper behaviour with built-in
111 //                     symbols?
112 spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
113                                   const opt::IRContext& linked_context,
114                                   const DefUseManager& def_use_manager,
115                                   const DecorationManager& decoration_manager,
116                                   bool allow_partial_linkage,
117                                   LinkageTable* linkings_to_do);
118 
119 // Checks that for each pair of import and export, the import and export have
120 // the same type as well as the same decorations.
121 //
122 // TODO(pierremoreau): Decorations on functions parameters are currently not
123 // checked.
124 spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
125                                             const LinkageTable& linkings_to_do,
126                                             opt::IRContext* context);
127 
128 // Remove linkage specific instructions, such as prototypes of imported
129 // functions, declarations of imported variables, import (and export if
130 // necessary) linkage attribtes.
131 //
132 // |linked_context| and |decoration_manager| should not be null, and the
133 // 'RemoveDuplicatePass' should be run first.
134 //
135 // TODO(pierremoreau): Linkage attributes applied by a group decoration are
136 //                     currently not handled. (You could have a group being
137 //                     applied to a single ID.)
138 // TODO(pierremoreau): Run a pass for removing dead instructions, for example
139 //                     OpName for prototypes of imported funcions.
140 spv_result_t RemoveLinkageSpecificInstructions(
141     const MessageConsumer& consumer, const LinkerOptions& options,
142     const LinkageTable& linkings_to_do, DecorationManager* decoration_manager,
143     opt::IRContext* linked_context);
144 
145 // Verify that the unique ids of each instruction in |linked_context| (i.e. the
146 // merged module) are truly unique. Does not check the validity of other ids
147 spv_result_t VerifyIds(const MessageConsumer& consumer,
148                        opt::IRContext* linked_context);
149 
ShiftIdsInModules(const MessageConsumer & consumer,std::vector<opt::Module * > * modules,uint32_t * max_id_bound)150 spv_result_t ShiftIdsInModules(const MessageConsumer& consumer,
151                                std::vector<opt::Module*>* modules,
152                                uint32_t* max_id_bound) {
153   spv_position_t position = {};
154 
155   if (modules == nullptr)
156     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
157            << "|modules| of ShiftIdsInModules should not be null.";
158   if (modules->empty())
159     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
160            << "|modules| of ShiftIdsInModules should not be empty.";
161   if (max_id_bound == nullptr)
162     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
163            << "|max_id_bound| of ShiftIdsInModules should not be null.";
164 
165   uint32_t id_bound = modules->front()->IdBound() - 1u;
166   for (auto module_iter = modules->begin() + 1; module_iter != modules->end();
167        ++module_iter) {
168     Module* module = *module_iter;
169     module->ForEachInst([&id_bound](Instruction* insn) {
170       insn->ForEachId([&id_bound](uint32_t* id) { *id += id_bound; });
171     });
172     id_bound += module->IdBound() - 1u;
173     if (id_bound > 0x3FFFFF)
174       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_ID)
175              << "The limit of IDs, 4194303, was exceeded:"
176              << " " << id_bound << " is the current ID bound.";
177 
178     // Invalidate the DefUseManager
179     module->context()->InvalidateAnalyses(opt::IRContext::kAnalysisDefUse);
180   }
181   ++id_bound;
182   if (id_bound > 0x3FFFFF)
183     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_ID)
184            << "The limit of IDs, 4194303, was exceeded:"
185            << " " << id_bound << " is the current ID bound.";
186 
187   *max_id_bound = id_bound;
188 
189   return SPV_SUCCESS;
190 }
191 
GenerateHeader(const MessageConsumer & consumer,const std::vector<opt::Module * > & modules,uint32_t max_id_bound,opt::ModuleHeader * header)192 spv_result_t GenerateHeader(const MessageConsumer& consumer,
193                             const std::vector<opt::Module*>& modules,
194                             uint32_t max_id_bound, opt::ModuleHeader* header) {
195   spv_position_t position = {};
196 
197   if (modules.empty())
198     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
199            << "|modules| of GenerateHeader should not be empty.";
200   if (max_id_bound == 0u)
201     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
202            << "|max_id_bound| of GenerateHeader should not be null.";
203 
204   uint32_t version = 0u;
205   for (const auto& module : modules)
206     version = std::max(version, module->version());
207 
208   header->magic_number = SpvMagicNumber;
209   header->version = version;
210   header->generator = 17u;
211   header->bound = max_id_bound;
212   header->reserved = 0u;
213 
214   return SPV_SUCCESS;
215 }
216 
MergeModules(const MessageConsumer & consumer,const std::vector<Module * > & input_modules,const AssemblyGrammar & grammar,IRContext * linked_context)217 spv_result_t MergeModules(const MessageConsumer& consumer,
218                           const std::vector<Module*>& input_modules,
219                           const AssemblyGrammar& grammar,
220                           IRContext* linked_context) {
221   spv_position_t position = {};
222 
223   if (linked_context == nullptr)
224     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
225            << "|linked_module| of MergeModules should not be null.";
226   Module* linked_module = linked_context->module();
227 
228   if (input_modules.empty()) return SPV_SUCCESS;
229 
230   for (const auto& module : input_modules)
231     for (const auto& inst : module->capabilities())
232       linked_module->AddCapability(
233           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
234 
235   for (const auto& module : input_modules)
236     for (const auto& inst : module->extensions())
237       linked_module->AddExtension(
238           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
239 
240   for (const auto& module : input_modules)
241     for (const auto& inst : module->ext_inst_imports())
242       linked_module->AddExtInstImport(
243           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
244 
245   do {
246     const Instruction* memory_model_inst = input_modules[0]->GetMemoryModel();
247     if (memory_model_inst == nullptr) break;
248 
249     uint32_t addressing_model = memory_model_inst->GetSingleWordOperand(0u);
250     uint32_t memory_model = memory_model_inst->GetSingleWordOperand(1u);
251     for (const auto& module : input_modules) {
252       memory_model_inst = module->GetMemoryModel();
253       if (memory_model_inst == nullptr) continue;
254 
255       if (addressing_model != memory_model_inst->GetSingleWordOperand(0u)) {
256         spv_operand_desc initial_desc = nullptr, current_desc = nullptr;
257         grammar.lookupOperand(SPV_OPERAND_TYPE_ADDRESSING_MODEL,
258                               addressing_model, &initial_desc);
259         grammar.lookupOperand(SPV_OPERAND_TYPE_ADDRESSING_MODEL,
260                               memory_model_inst->GetSingleWordOperand(0u),
261                               &current_desc);
262         return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
263                << "Conflicting addressing models: " << initial_desc->name
264                << " vs " << current_desc->name << ".";
265       }
266       if (memory_model != memory_model_inst->GetSingleWordOperand(1u)) {
267         spv_operand_desc initial_desc = nullptr, current_desc = nullptr;
268         grammar.lookupOperand(SPV_OPERAND_TYPE_MEMORY_MODEL, memory_model,
269                               &initial_desc);
270         grammar.lookupOperand(SPV_OPERAND_TYPE_MEMORY_MODEL,
271                               memory_model_inst->GetSingleWordOperand(1u),
272                               &current_desc);
273         return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
274                << "Conflicting memory models: " << initial_desc->name << " vs "
275                << current_desc->name << ".";
276       }
277     }
278 
279     if (memory_model_inst != nullptr)
280       linked_module->SetMemoryModel(std::unique_ptr<Instruction>(
281           memory_model_inst->Clone(linked_context)));
282   } while (false);
283 
284   std::vector<std::pair<uint32_t, const char*>> entry_points;
285   for (const auto& module : input_modules)
286     for (const auto& inst : module->entry_points()) {
287       const uint32_t model = inst.GetSingleWordInOperand(0);
288       const char* const name =
289           reinterpret_cast<const char*>(inst.GetInOperand(2).words.data());
290       const auto i = std::find_if(
291           entry_points.begin(), entry_points.end(),
292           [model, name](const std::pair<uint32_t, const char*>& v) {
293             return v.first == model && strcmp(name, v.second) == 0;
294           });
295       if (i != entry_points.end()) {
296         spv_operand_desc desc = nullptr;
297         grammar.lookupOperand(SPV_OPERAND_TYPE_EXECUTION_MODEL, model, &desc);
298         return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
299                << "The entry point \"" << name << "\", with execution model "
300                << desc->name << ", was already defined.";
301       }
302       linked_module->AddEntryPoint(
303           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
304       entry_points.emplace_back(model, name);
305     }
306 
307   for (const auto& module : input_modules)
308     for (const auto& inst : module->execution_modes())
309       linked_module->AddExecutionMode(
310           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
311 
312   for (const auto& module : input_modules)
313     for (const auto& inst : module->debugs1())
314       linked_module->AddDebug1Inst(
315           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
316 
317   for (const auto& module : input_modules)
318     for (const auto& inst : module->debugs2())
319       linked_module->AddDebug2Inst(
320           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
321 
322   for (const auto& module : input_modules)
323     for (const auto& inst : module->debugs3())
324       linked_module->AddDebug3Inst(
325           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
326 
327   // If the generated module uses SPIR-V 1.1 or higher, add an
328   // OpModuleProcessed instruction about the linking step.
329   if (linked_module->version() >= 0x10100) {
330     const std::string processed_string("Linked by SPIR-V Tools Linker");
331     const auto num_chars = processed_string.size();
332     // Compute num words, accommodate the terminating null character.
333     const auto num_words = (num_chars + 1 + 3) / 4;
334     std::vector<uint32_t> processed_words(num_words, 0u);
335     std::memcpy(processed_words.data(), processed_string.data(), num_chars);
336     linked_module->AddDebug3Inst(std::unique_ptr<Instruction>(
337         new Instruction(linked_context, SpvOpModuleProcessed, 0u, 0u,
338                         {{SPV_OPERAND_TYPE_LITERAL_STRING, processed_words}})));
339   }
340 
341   for (const auto& module : input_modules)
342     for (const auto& inst : module->annotations())
343       linked_module->AddAnnotationInst(
344           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
345 
346   // TODO(pierremoreau): Since the modules have not been validate, should we
347   //                     expect SpvStorageClassFunction variables outside
348   //                     functions?
349   uint32_t num_global_values = 0u;
350   for (const auto& module : input_modules) {
351     for (const auto& inst : module->types_values()) {
352       linked_module->AddType(
353           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
354       num_global_values += inst.opcode() == SpvOpVariable;
355     }
356   }
357   if (num_global_values > 0xFFFF)
358     return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
359            << "The limit of global values, 65535, was exceeded;"
360            << " " << num_global_values << " global values were found.";
361 
362   // Process functions and their basic blocks
363   for (const auto& module : input_modules) {
364     for (const auto& func : *module) {
365       std::unique_ptr<opt::Function> cloned_func(func.Clone(linked_context));
366       linked_module->AddFunction(std::move(cloned_func));
367     }
368   }
369 
370   return SPV_SUCCESS;
371 }
372 
GetImportExportPairs(const MessageConsumer & consumer,const opt::IRContext & linked_context,const DefUseManager & def_use_manager,const DecorationManager & decoration_manager,bool allow_partial_linkage,LinkageTable * linkings_to_do)373 spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
374                                   const opt::IRContext& linked_context,
375                                   const DefUseManager& def_use_manager,
376                                   const DecorationManager& decoration_manager,
377                                   bool allow_partial_linkage,
378                                   LinkageTable* linkings_to_do) {
379   spv_position_t position = {};
380 
381   if (linkings_to_do == nullptr)
382     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
383            << "|linkings_to_do| of GetImportExportPairs should not be empty.";
384 
385   std::vector<LinkageSymbolInfo> imports;
386   std::unordered_map<std::string, std::vector<LinkageSymbolInfo>> exports;
387 
388   // Figure out the imports and exports
389   for (const auto& decoration : linked_context.annotations()) {
390     if (decoration.opcode() != SpvOpDecorate ||
391         decoration.GetSingleWordInOperand(1u) != SpvDecorationLinkageAttributes)
392       continue;
393 
394     const SpvId id = decoration.GetSingleWordInOperand(0u);
395     // Ignore if the targeted symbol is a built-in
396     bool is_built_in = false;
397     for (const auto& id_decoration :
398          decoration_manager.GetDecorationsFor(id, false)) {
399       if (id_decoration->GetSingleWordInOperand(1u) == SpvDecorationBuiltIn) {
400         is_built_in = true;
401         break;
402       }
403     }
404     if (is_built_in) {
405       continue;
406     }
407 
408     const uint32_t type = decoration.GetSingleWordInOperand(3u);
409 
410     LinkageSymbolInfo symbol_info;
411     symbol_info.name =
412         reinterpret_cast<const char*>(decoration.GetInOperand(2u).words.data());
413     symbol_info.id = id;
414     symbol_info.type_id = 0u;
415 
416     // Retrieve the type of the current symbol. This information will be used
417     // when checking that the imported and exported symbols have the same
418     // types.
419     const Instruction* def_inst = def_use_manager.GetDef(id);
420     if (def_inst == nullptr)
421       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
422              << "ID " << id << " is never defined:\n";
423 
424     if (def_inst->opcode() == SpvOpVariable) {
425       symbol_info.type_id = def_inst->type_id();
426     } else if (def_inst->opcode() == SpvOpFunction) {
427       symbol_info.type_id = def_inst->GetSingleWordInOperand(1u);
428 
429       // range-based for loop calls begin()/end(), but never cbegin()/cend(),
430       // which will not work here.
431       for (auto func_iter = linked_context.module()->cbegin();
432            func_iter != linked_context.module()->cend(); ++func_iter) {
433         if (func_iter->result_id() != id) continue;
434         func_iter->ForEachParam([&symbol_info](const Instruction* inst) {
435           symbol_info.parameter_ids.push_back(inst->result_id());
436         });
437       }
438     } else {
439       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
440              << "Only global variables and functions can be decorated using"
441              << " LinkageAttributes; " << id << " is neither of them.\n";
442     }
443 
444     if (type == SpvLinkageTypeImport)
445       imports.push_back(symbol_info);
446     else if (type == SpvLinkageTypeExport)
447       exports[symbol_info.name].push_back(symbol_info);
448   }
449 
450   // Find the import/export pairs
451   for (const auto& import : imports) {
452     std::vector<LinkageSymbolInfo> possible_exports;
453     const auto& exp = exports.find(import.name);
454     if (exp != exports.end()) possible_exports = exp->second;
455     if (possible_exports.empty() && !allow_partial_linkage)
456       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
457              << "Unresolved external reference to \"" << import.name << "\".";
458     else if (possible_exports.size() > 1u)
459       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
460              << "Too many external references, " << possible_exports.size()
461              << ", were found for \"" << import.name << "\".";
462 
463     if (!possible_exports.empty())
464       linkings_to_do->emplace_back(import, possible_exports.front());
465   }
466 
467   return SPV_SUCCESS;
468 }
469 
CheckImportExportCompatibility(const MessageConsumer & consumer,const LinkageTable & linkings_to_do,opt::IRContext * context)470 spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
471                                             const LinkageTable& linkings_to_do,
472                                             opt::IRContext* context) {
473   spv_position_t position = {};
474 
475   // Ensure th import and export types are the same.
476   const DefUseManager& def_use_manager = *context->get_def_use_mgr();
477   const DecorationManager& decoration_manager = *context->get_decoration_mgr();
478   for (const auto& linking_entry : linkings_to_do) {
479     if (!RemoveDuplicatesPass::AreTypesEqual(
480             *def_use_manager.GetDef(linking_entry.imported_symbol.type_id),
481             *def_use_manager.GetDef(linking_entry.exported_symbol.type_id),
482             context))
483       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
484              << "Type mismatch on symbol \""
485              << linking_entry.imported_symbol.name
486              << "\" between imported variable/function %"
487              << linking_entry.imported_symbol.id
488              << " and exported variable/function %"
489              << linking_entry.exported_symbol.id << ".";
490   }
491 
492   // Ensure the import and export decorations are similar
493   for (const auto& linking_entry : linkings_to_do) {
494     if (!decoration_manager.HaveTheSameDecorations(
495             linking_entry.imported_symbol.id, linking_entry.exported_symbol.id))
496       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
497              << "Decorations mismatch on symbol \""
498              << linking_entry.imported_symbol.name
499              << "\" between imported variable/function %"
500              << linking_entry.imported_symbol.id
501              << " and exported variable/function %"
502              << linking_entry.exported_symbol.id << ".";
503     // TODO(pierremoreau): Decorations on function parameters should probably
504     //                     match, except for FuncParamAttr if I understand the
505     //                     spec correctly.
506     // TODO(pierremoreau): Decorations on the function return type should
507     //                     match, except for FuncParamAttr.
508   }
509 
510   return SPV_SUCCESS;
511 }
512 
RemoveLinkageSpecificInstructions(const MessageConsumer & consumer,const LinkerOptions & options,const LinkageTable & linkings_to_do,DecorationManager * decoration_manager,opt::IRContext * linked_context)513 spv_result_t RemoveLinkageSpecificInstructions(
514     const MessageConsumer& consumer, const LinkerOptions& options,
515     const LinkageTable& linkings_to_do, DecorationManager* decoration_manager,
516     opt::IRContext* linked_context) {
517   spv_position_t position = {};
518 
519   if (decoration_manager == nullptr)
520     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
521            << "|decoration_manager| of RemoveLinkageSpecificInstructions "
522               "should not be empty.";
523   if (linked_context == nullptr)
524     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
525            << "|linked_module| of RemoveLinkageSpecificInstructions should not "
526               "be empty.";
527 
528   // TODO(pierremoreau): Remove FuncParamAttr decorations of imported
529   // functions' return type.
530 
531   // Remove FuncParamAttr decorations of imported functions' parameters.
532   // From the SPIR-V specification, Sec. 2.13:
533   //   When resolving imported functions, the Function Control and all Function
534   //   Parameter Attributes are taken from the function definition, and not
535   //   from the function declaration.
536   for (const auto& linking_entry : linkings_to_do) {
537     for (const auto parameter_id :
538          linking_entry.imported_symbol.parameter_ids) {
539       decoration_manager->RemoveDecorationsFrom(
540           parameter_id, [](const Instruction& inst) {
541             return (inst.opcode() == SpvOpDecorate ||
542                     inst.opcode() == SpvOpMemberDecorate) &&
543                    inst.GetSingleWordInOperand(1u) ==
544                        SpvDecorationFuncParamAttr;
545           });
546     }
547   }
548 
549   // Remove prototypes of imported functions
550   for (const auto& linking_entry : linkings_to_do) {
551     for (auto func_iter = linked_context->module()->begin();
552          func_iter != linked_context->module()->end();) {
553       if (func_iter->result_id() == linking_entry.imported_symbol.id)
554         func_iter = func_iter.Erase();
555       else
556         ++func_iter;
557     }
558   }
559 
560   // Remove declarations of imported variables
561   for (const auto& linking_entry : linkings_to_do) {
562     auto next = linked_context->types_values_begin();
563     for (auto inst = next; inst != linked_context->types_values_end();
564          inst = next) {
565       ++next;
566       if (inst->result_id() == linking_entry.imported_symbol.id) {
567         linked_context->KillInst(&*inst);
568       }
569     }
570   }
571 
572   // If partial linkage is allowed, we need an efficient way to check whether
573   // an imported ID had a corresponding export symbol. As uses of the imported
574   // symbol have already been replaced by the exported symbol, use the exported
575   // symbol ID.
576   // TODO(pierremoreau): This will not work if the decoration is applied
577   //                     through a group, but the linker does not support that
578   //                     either.
579   std::unordered_set<SpvId> imports;
580   if (options.GetAllowPartialLinkage()) {
581     imports.reserve(linkings_to_do.size());
582     for (const auto& linking_entry : linkings_to_do)
583       imports.emplace(linking_entry.exported_symbol.id);
584   }
585 
586   // Remove import linkage attributes
587   auto next = linked_context->annotation_begin();
588   for (auto inst = next; inst != linked_context->annotation_end();
589        inst = next) {
590     ++next;
591     // If this is an import annotation:
592     // * if we do not allow partial linkage, remove all import annotations;
593     // * otherwise, remove the annotation only if there was a corresponding
594     //   export.
595     if (inst->opcode() == SpvOpDecorate &&
596         inst->GetSingleWordOperand(1u) == SpvDecorationLinkageAttributes &&
597         inst->GetSingleWordOperand(3u) == SpvLinkageTypeImport &&
598         (!options.GetAllowPartialLinkage() ||
599          imports.find(inst->GetSingleWordOperand(0u)) != imports.end())) {
600       linked_context->KillInst(&*inst);
601     }
602   }
603 
604   // Remove export linkage attributes if making an executable
605   if (!options.GetCreateLibrary()) {
606     next = linked_context->annotation_begin();
607     for (auto inst = next; inst != linked_context->annotation_end();
608          inst = next) {
609       ++next;
610       if (inst->opcode() == SpvOpDecorate &&
611           inst->GetSingleWordOperand(1u) == SpvDecorationLinkageAttributes &&
612           inst->GetSingleWordOperand(3u) == SpvLinkageTypeExport) {
613         linked_context->KillInst(&*inst);
614       }
615     }
616   }
617 
618   // Remove Linkage capability if making an executable and partial linkage is
619   // not allowed
620   if (!options.GetCreateLibrary() && !options.GetAllowPartialLinkage()) {
621     for (auto& inst : linked_context->capabilities())
622       if (inst.GetSingleWordInOperand(0u) == SpvCapabilityLinkage) {
623         linked_context->KillInst(&inst);
624         // The RemoveDuplicatesPass did remove duplicated capabilities, so we
625         // now there aren’t more SpvCapabilityLinkage further down.
626         break;
627       }
628   }
629 
630   return SPV_SUCCESS;
631 }
632 
VerifyIds(const MessageConsumer & consumer,opt::IRContext * linked_context)633 spv_result_t VerifyIds(const MessageConsumer& consumer,
634                        opt::IRContext* linked_context) {
635   std::unordered_set<uint32_t> ids;
636   bool ok = true;
637   linked_context->module()->ForEachInst(
638       [&ids, &ok](const opt::Instruction* inst) {
639         ok &= ids.insert(inst->unique_id()).second;
640       });
641 
642   if (!ok) {
643     consumer(SPV_MSG_INTERNAL_ERROR, "", {}, "Non-unique id in merged module");
644     return SPV_ERROR_INVALID_ID;
645   }
646 
647   return SPV_SUCCESS;
648 }
649 
650 }  // namespace
651 
Link(const Context & context,const std::vector<std::vector<uint32_t>> & binaries,std::vector<uint32_t> * linked_binary,const LinkerOptions & options)652 spv_result_t Link(const Context& context,
653                   const std::vector<std::vector<uint32_t>>& binaries,
654                   std::vector<uint32_t>* linked_binary,
655                   const LinkerOptions& options) {
656   std::vector<const uint32_t*> binary_ptrs;
657   binary_ptrs.reserve(binaries.size());
658   std::vector<size_t> binary_sizes;
659   binary_sizes.reserve(binaries.size());
660 
661   for (const auto& binary : binaries) {
662     binary_ptrs.push_back(binary.data());
663     binary_sizes.push_back(binary.size());
664   }
665 
666   return Link(context, binary_ptrs.data(), binary_sizes.data(), binaries.size(),
667               linked_binary, options);
668 }
669 
Link(const Context & context,const uint32_t * const * binaries,const size_t * binary_sizes,size_t num_binaries,std::vector<uint32_t> * linked_binary,const LinkerOptions & options)670 spv_result_t Link(const Context& context, const uint32_t* const* binaries,
671                   const size_t* binary_sizes, size_t num_binaries,
672                   std::vector<uint32_t>* linked_binary,
673                   const LinkerOptions& options) {
674   spv_position_t position = {};
675   const spv_context& c_context = context.CContext();
676   const MessageConsumer& consumer = c_context->consumer;
677 
678   linked_binary->clear();
679   if (num_binaries == 0u)
680     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
681            << "No modules were given.";
682 
683   std::vector<std::unique_ptr<IRContext>> ir_contexts;
684   std::vector<Module*> modules;
685   modules.reserve(num_binaries);
686   for (size_t i = 0u; i < num_binaries; ++i) {
687     const uint32_t schema = binaries[i][4u];
688     if (schema != 0u) {
689       position.index = 4u;
690       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
691              << "Schema is non-zero for module " << i << ".";
692     }
693 
694     std::unique_ptr<IRContext> ir_context = BuildModule(
695         c_context->target_env, consumer, binaries[i], binary_sizes[i]);
696     if (ir_context == nullptr)
697       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
698              << "Failed to build a module out of " << ir_contexts.size() << ".";
699     modules.push_back(ir_context->module());
700     ir_contexts.push_back(std::move(ir_context));
701   }
702 
703   // Phase 1: Shift the IDs used in each binary so that they occupy a disjoint
704   //          range from the other binaries, and compute the new ID bound.
705   uint32_t max_id_bound = 0u;
706   spv_result_t res = ShiftIdsInModules(consumer, &modules, &max_id_bound);
707   if (res != SPV_SUCCESS) return res;
708 
709   // Phase 2: Generate the header
710   opt::ModuleHeader header;
711   res = GenerateHeader(consumer, modules, max_id_bound, &header);
712   if (res != SPV_SUCCESS) return res;
713   IRContext linked_context(c_context->target_env, consumer);
714   linked_context.module()->SetHeader(header);
715 
716   // Phase 3: Merge all the binaries into a single one.
717   AssemblyGrammar grammar(c_context);
718   res = MergeModules(consumer, modules, grammar, &linked_context);
719   if (res != SPV_SUCCESS) return res;
720 
721   if (options.GetVerifyIds()) {
722     res = VerifyIds(consumer, &linked_context);
723     if (res != SPV_SUCCESS) return res;
724   }
725 
726   // Phase 4: Find the import/export pairs
727   LinkageTable linkings_to_do;
728   res = GetImportExportPairs(consumer, linked_context,
729                              *linked_context.get_def_use_mgr(),
730                              *linked_context.get_decoration_mgr(),
731                              options.GetAllowPartialLinkage(), &linkings_to_do);
732   if (res != SPV_SUCCESS) return res;
733 
734   // Phase 5: Ensure the import and export have the same types and decorations.
735   res =
736       CheckImportExportCompatibility(consumer, linkings_to_do, &linked_context);
737   if (res != SPV_SUCCESS) return res;
738 
739   // Phase 6: Remove duplicates
740   PassManager manager;
741   manager.SetMessageConsumer(consumer);
742   manager.AddPass<RemoveDuplicatesPass>();
743   opt::Pass::Status pass_res = manager.Run(&linked_context);
744   if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
745 
746   // Phase 7: Rematch import variables/functions to export variables/functions
747   for (const auto& linking_entry : linkings_to_do)
748     linked_context.ReplaceAllUsesWith(linking_entry.imported_symbol.id,
749                                       linking_entry.exported_symbol.id);
750 
751   // Phase 8: Remove linkage specific instructions, such as import/export
752   // attributes, linkage capability, etc. if applicable
753   res = RemoveLinkageSpecificInstructions(consumer, options, linkings_to_do,
754                                           linked_context.get_decoration_mgr(),
755                                           &linked_context);
756   if (res != SPV_SUCCESS) return res;
757 
758   // Phase 9: Compact the IDs used in the module
759   manager.AddPass<opt::CompactIdsPass>();
760   pass_res = manager.Run(&linked_context);
761   if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
762 
763   // Phase 10: Output the module
764   linked_context.module()->ToBinary(linked_binary, true);
765 
766   return SPV_SUCCESS;
767 }
768 
769 }  // namespace spvtools
770