1 //===- SymbolRewriter.cpp - Symbol Rewriter ---------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
// SymbolRewriter is an LLVM pass which can rewrite symbols transparently within
11 // existing code.  It is implemented as a compiler pass and is configured via a
12 // YAML configuration file.
13 //
14 // The YAML configuration file format is as follows:
15 //
16 // RewriteMapFile := RewriteDescriptors
// RewriteDescriptors := RewriteDescriptor
//                     | RewriteDescriptor RewriteDescriptors
// RewriteDescriptor := RewriteDescriptorType ':' '{' RewriteDescriptorFields '}'
// RewriteDescriptorFields := RewriteDescriptorField
//                          | RewriteDescriptorField RewriteDescriptorFields
20 // RewriteDescriptorField := FieldIdentifier ':' FieldValue ','
21 // RewriteDescriptorType := Identifier
22 // FieldIdentifier := Identifier
23 // FieldValue := Identifier
24 // Identifier := [0-9a-zA-Z]+
25 //
26 // Currently, the following descriptor types are supported:
27 //
// - function:          (function rewriting)
//      + source        (original name of the function)
//      + target        (explicit transformation)
//      + transform     (pattern transformation)
//      + naked         (boolean, whether the function is undecorated)
// - global variable:   (external linkage global variable rewriting)
//      + source        (original name of externally visible variable)
//      + target        (explicit transformation)
//      + transform     (pattern transformation)
// - global alias:      (global alias rewriting)
//      + source        (original name of the aliased name)
//      + target        (explicit transformation)
//      + transform     (pattern transformation)
//
// Note that `source` and exactly one of [`target`, `transform`] must be
// provided.  With `target`, `source` is the exact symbol name to rename; with
// `transform`, `source` is a regular expression and `transform` is the
// replacement applied to every matching symbol name.
43 //
// New rewrite descriptors can be created.  Adding a new rewrite descriptor
// involves:
//
//  a) extending the rewrite descriptor kind enumeration
//     (RewriteDescriptor::Type)
//  b) implementing the new descriptor
//     (cf. <anonymous>::ExplicitRewriteFunctionDescriptor)
//  c) extending the rewrite map parser
//     (SymbolRewriter::RewriteMapParser::parseEntry)
53 //
//  Enable symbol rewriting with the `-rewrite-symbols` option, and specify
//  the map file(s) to use for the rewriting via the `-rewrite-map-file`
//  option.
57 //
58 //===----------------------------------------------------------------------===//
59 
60 #define DEBUG_TYPE "symbol-rewriter"
61 #include "llvm/Pass.h"
62 #include "llvm/ADT/SmallString.h"
63 #include "llvm/IR/LegacyPassManager.h"
64 #include "llvm/Support/CommandLine.h"
65 #include "llvm/Support/Debug.h"
66 #include "llvm/Support/MemoryBuffer.h"
67 #include "llvm/Support/Regex.h"
68 #include "llvm/Support/SourceMgr.h"
69 #include "llvm/Support/YAMLParser.h"
70 #include "llvm/Support/raw_ostream.h"
71 #include "llvm/Transforms/Utils/SymbolRewriter.h"
72 
73 using namespace llvm;
74 using namespace SymbolRewriter;
75 
76 static cl::list<std::string> RewriteMapFiles("rewrite-map-file",
77                                              cl::desc("Symbol Rewrite Map"),
78                                              cl::value_desc("filename"));
79 
rewriteComdat(Module & M,GlobalObject * GO,const std::string & Source,const std::string & Target)80 static void rewriteComdat(Module &M, GlobalObject *GO,
81                           const std::string &Source,
82                           const std::string &Target) {
83   if (Comdat *CD = GO->getComdat()) {
84     auto &Comdats = M.getComdatSymbolTable();
85 
86     Comdat *C = M.getOrInsertComdat(Target);
87     C->setSelectionKind(CD->getSelectionKind());
88     GO->setComdat(C);
89 
90     Comdats.erase(Comdats.find(Source));
91   }
92 }
93 
94 namespace {
95 template <RewriteDescriptor::Type DT, typename ValueType,
96           ValueType *(llvm::Module::*Get)(StringRef) const>
97 class ExplicitRewriteDescriptor : public RewriteDescriptor {
98 public:
99   const std::string Source;
100   const std::string Target;
101 
ExplicitRewriteDescriptor(StringRef S,StringRef T,const bool Naked)102   ExplicitRewriteDescriptor(StringRef S, StringRef T, const bool Naked)
103       : RewriteDescriptor(DT), Source(Naked ? StringRef("\01" + S.str()) : S),
104         Target(T) {}
105 
106   bool performOnModule(Module &M) override;
107 
classof(const RewriteDescriptor * RD)108   static bool classof(const RewriteDescriptor *RD) {
109     return RD->getType() == DT;
110   }
111 };
112 
113 template <RewriteDescriptor::Type DT, typename ValueType,
114           ValueType *(llvm::Module::*Get)(StringRef) const>
performOnModule(Module & M)115 bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) {
116   bool Changed = false;
117   if (ValueType *S = (M.*Get)(Source)) {
118     if (GlobalObject *GO = dyn_cast<GlobalObject>(S))
119       rewriteComdat(M, GO, Source, Target);
120 
121     if (Value *T = (M.*Get)(Target))
122       S->setValueName(T->getValueName());
123     else
124       S->setName(Target);
125 
126     Changed = true;
127   }
128   return Changed;
129 }
130 
131 template <RewriteDescriptor::Type DT, typename ValueType,
132           ValueType *(llvm::Module::*Get)(StringRef) const,
133           iterator_range<typename iplist<ValueType>::iterator>
134           (llvm::Module::*Iterator)()>
135 class PatternRewriteDescriptor : public RewriteDescriptor {
136 public:
137   const std::string Pattern;
138   const std::string Transform;
139 
PatternRewriteDescriptor(StringRef P,StringRef T)140   PatternRewriteDescriptor(StringRef P, StringRef T)
141     : RewriteDescriptor(DT), Pattern(P), Transform(T) { }
142 
143   bool performOnModule(Module &M) override;
144 
classof(const RewriteDescriptor * RD)145   static bool classof(const RewriteDescriptor *RD) {
146     return RD->getType() == DT;
147   }
148 };
149 
150 template <RewriteDescriptor::Type DT, typename ValueType,
151           ValueType *(llvm::Module::*Get)(StringRef) const,
152           iterator_range<typename iplist<ValueType>::iterator>
153           (llvm::Module::*Iterator)()>
154 bool PatternRewriteDescriptor<DT, ValueType, Get, Iterator>::
performOnModule(Module & M)155 performOnModule(Module &M) {
156   bool Changed = false;
157   for (auto &C : (M.*Iterator)()) {
158     std::string Error;
159 
160     std::string Name = Regex(Pattern).sub(Transform, C.getName(), &Error);
161     if (!Error.empty())
162       report_fatal_error("unable to transforn " + C.getName() + " in " +
163                          M.getModuleIdentifier() + ": " + Error);
164 
165     if (C.getName() == Name)
166       continue;
167 
168     if (GlobalObject *GO = dyn_cast<GlobalObject>(&C))
169       rewriteComdat(M, GO, C.getName(), Name);
170 
171     if (Value *V = (M.*Get)(Name))
172       C.setValueName(V->getValueName());
173     else
174       C.setName(Name);
175 
176     Changed = true;
177   }
178   return Changed;
179 }
180 
181 /// Represents a rewrite for an explicitly named (function) symbol.  Both the
182 /// source function name and target function name of the transformation are
183 /// explicitly spelt out.
184 typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::Function,
185                                   llvm::Function, &llvm::Module::getFunction>
186     ExplicitRewriteFunctionDescriptor;
187 
188 /// Represents a rewrite for an explicitly named (global variable) symbol.  Both
189 /// the source variable name and target variable name are spelt out.  This
190 /// applies only to module level variables.
191 typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
192                                   llvm::GlobalVariable,
193                                   &llvm::Module::getGlobalVariable>
194     ExplicitRewriteGlobalVariableDescriptor;
195 
196 /// Represents a rewrite for an explicitly named global alias.  Both the source
197 /// and target name are explicitly spelt out.
198 typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::NamedAlias,
199                                   llvm::GlobalAlias,
200                                   &llvm::Module::getNamedAlias>
201     ExplicitRewriteNamedAliasDescriptor;
202 
203 /// Represents a rewrite for a regular expression based pattern for functions.
204 /// A pattern for the function name is provided and a transformation for that
205 /// pattern to determine the target function name create the rewrite rule.
206 typedef PatternRewriteDescriptor<RewriteDescriptor::Type::Function,
207                                  llvm::Function, &llvm::Module::getFunction,
208                                  &llvm::Module::functions>
209     PatternRewriteFunctionDescriptor;
210 
211 /// Represents a rewrite for a global variable based upon a matching pattern.
212 /// Each global variable matching the provided pattern will be transformed as
213 /// described in the transformation pattern for the target.  Applies only to
214 /// module level variables.
215 typedef PatternRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
216                                  llvm::GlobalVariable,
217                                  &llvm::Module::getGlobalVariable,
218                                  &llvm::Module::globals>
219     PatternRewriteGlobalVariableDescriptor;
220 
221 /// PatternRewriteNamedAliasDescriptor - represents a rewrite for global
222 /// aliases which match a given pattern.  The provided transformation will be
223 /// applied to each of the matching names.
224 typedef PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias,
225                                  llvm::GlobalAlias,
226                                  &llvm::Module::getNamedAlias,
227                                  &llvm::Module::aliases>
228     PatternRewriteNamedAliasDescriptor;
229 } // namespace
230 
parse(const std::string & MapFile,RewriteDescriptorList * DL)231 bool RewriteMapParser::parse(const std::string &MapFile,
232                              RewriteDescriptorList *DL) {
233   ErrorOr<std::unique_ptr<MemoryBuffer>> Mapping =
234       MemoryBuffer::getFile(MapFile);
235 
236   if (!Mapping)
237     report_fatal_error("unable to read rewrite map '" + MapFile + "': " +
238                        Mapping.getError().message());
239 
240   if (!parse(*Mapping, DL))
241     report_fatal_error("unable to parse rewrite map '" + MapFile + "'");
242 
243   return true;
244 }
245 
parse(std::unique_ptr<MemoryBuffer> & MapFile,RewriteDescriptorList * DL)246 bool RewriteMapParser::parse(std::unique_ptr<MemoryBuffer> &MapFile,
247                              RewriteDescriptorList *DL) {
248   SourceMgr SM;
249   yaml::Stream YS(MapFile->getBuffer(), SM);
250 
251   for (auto &Document : YS) {
252     yaml::MappingNode *DescriptorList;
253 
254     // ignore empty documents
255     if (isa<yaml::NullNode>(Document.getRoot()))
256       continue;
257 
258     DescriptorList = dyn_cast<yaml::MappingNode>(Document.getRoot());
259     if (!DescriptorList) {
260       YS.printError(Document.getRoot(), "DescriptorList node must be a map");
261       return false;
262     }
263 
264     for (auto &Descriptor : *DescriptorList)
265       if (!parseEntry(YS, Descriptor, DL))
266         return false;
267   }
268 
269   return true;
270 }
271 
parseEntry(yaml::Stream & YS,yaml::KeyValueNode & Entry,RewriteDescriptorList * DL)272 bool RewriteMapParser::parseEntry(yaml::Stream &YS, yaml::KeyValueNode &Entry,
273                                   RewriteDescriptorList *DL) {
274   yaml::ScalarNode *Key;
275   yaml::MappingNode *Value;
276   SmallString<32> KeyStorage;
277   StringRef RewriteType;
278 
279   Key = dyn_cast<yaml::ScalarNode>(Entry.getKey());
280   if (!Key) {
281     YS.printError(Entry.getKey(), "rewrite type must be a scalar");
282     return false;
283   }
284 
285   Value = dyn_cast<yaml::MappingNode>(Entry.getValue());
286   if (!Value) {
287     YS.printError(Entry.getValue(), "rewrite descriptor must be a map");
288     return false;
289   }
290 
291   RewriteType = Key->getValue(KeyStorage);
292   if (RewriteType.equals("function"))
293     return parseRewriteFunctionDescriptor(YS, Key, Value, DL);
294   else if (RewriteType.equals("global variable"))
295     return parseRewriteGlobalVariableDescriptor(YS, Key, Value, DL);
296   else if (RewriteType.equals("global alias"))
297     return parseRewriteGlobalAliasDescriptor(YS, Key, Value, DL);
298 
299   YS.printError(Entry.getKey(), "unknown rewrite type");
300   return false;
301 }
302 
303 bool RewriteMapParser::
parseRewriteFunctionDescriptor(yaml::Stream & YS,yaml::ScalarNode * K,yaml::MappingNode * Descriptor,RewriteDescriptorList * DL)304 parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
305                                yaml::MappingNode *Descriptor,
306                                RewriteDescriptorList *DL) {
307   bool Naked = false;
308   std::string Source;
309   std::string Target;
310   std::string Transform;
311 
312   for (auto &Field : *Descriptor) {
313     yaml::ScalarNode *Key;
314     yaml::ScalarNode *Value;
315     SmallString<32> KeyStorage;
316     SmallString<32> ValueStorage;
317     StringRef KeyValue;
318 
319     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
320     if (!Key) {
321       YS.printError(Field.getKey(), "descriptor key must be a scalar");
322       return false;
323     }
324 
325     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
326     if (!Value) {
327       YS.printError(Field.getValue(), "descriptor value must be a scalar");
328       return false;
329     }
330 
331     KeyValue = Key->getValue(KeyStorage);
332     if (KeyValue.equals("source")) {
333       std::string Error;
334 
335       Source = Value->getValue(ValueStorage);
336       if (!Regex(Source).isValid(Error)) {
337         YS.printError(Field.getKey(), "invalid regex: " + Error);
338         return false;
339       }
340     } else if (KeyValue.equals("target")) {
341       Target = Value->getValue(ValueStorage);
342     } else if (KeyValue.equals("transform")) {
343       Transform = Value->getValue(ValueStorage);
344     } else if (KeyValue.equals("naked")) {
345       std::string Undecorated;
346 
347       Undecorated = Value->getValue(ValueStorage);
348       Naked = StringRef(Undecorated).lower() == "true" || Undecorated == "1";
349     } else {
350       YS.printError(Field.getKey(), "unknown key for function");
351       return false;
352     }
353   }
354 
355   if (Transform.empty() == Target.empty()) {
356     YS.printError(Descriptor,
357                   "exactly one of transform or target must be specified");
358     return false;
359   }
360 
361   // TODO see if there is a more elegant solution to selecting the rewrite
362   // descriptor type
363   if (!Target.empty())
364     DL->push_back(new ExplicitRewriteFunctionDescriptor(Source, Target, Naked));
365   else
366     DL->push_back(new PatternRewriteFunctionDescriptor(Source, Transform));
367 
368   return true;
369 }
370 
371 bool RewriteMapParser::
parseRewriteGlobalVariableDescriptor(yaml::Stream & YS,yaml::ScalarNode * K,yaml::MappingNode * Descriptor,RewriteDescriptorList * DL)372 parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
373                                      yaml::MappingNode *Descriptor,
374                                      RewriteDescriptorList *DL) {
375   std::string Source;
376   std::string Target;
377   std::string Transform;
378 
379   for (auto &Field : *Descriptor) {
380     yaml::ScalarNode *Key;
381     yaml::ScalarNode *Value;
382     SmallString<32> KeyStorage;
383     SmallString<32> ValueStorage;
384     StringRef KeyValue;
385 
386     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
387     if (!Key) {
388       YS.printError(Field.getKey(), "descriptor Key must be a scalar");
389       return false;
390     }
391 
392     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
393     if (!Value) {
394       YS.printError(Field.getValue(), "descriptor value must be a scalar");
395       return false;
396     }
397 
398     KeyValue = Key->getValue(KeyStorage);
399     if (KeyValue.equals("source")) {
400       std::string Error;
401 
402       Source = Value->getValue(ValueStorage);
403       if (!Regex(Source).isValid(Error)) {
404         YS.printError(Field.getKey(), "invalid regex: " + Error);
405         return false;
406       }
407     } else if (KeyValue.equals("target")) {
408       Target = Value->getValue(ValueStorage);
409     } else if (KeyValue.equals("transform")) {
410       Transform = Value->getValue(ValueStorage);
411     } else {
412       YS.printError(Field.getKey(), "unknown Key for Global Variable");
413       return false;
414     }
415   }
416 
417   if (Transform.empty() == Target.empty()) {
418     YS.printError(Descriptor,
419                   "exactly one of transform or target must be specified");
420     return false;
421   }
422 
423   if (!Target.empty())
424     DL->push_back(new ExplicitRewriteGlobalVariableDescriptor(Source, Target,
425                                                               /*Naked*/false));
426   else
427     DL->push_back(new PatternRewriteGlobalVariableDescriptor(Source,
428                                                              Transform));
429 
430   return true;
431 }
432 
433 bool RewriteMapParser::
parseRewriteGlobalAliasDescriptor(yaml::Stream & YS,yaml::ScalarNode * K,yaml::MappingNode * Descriptor,RewriteDescriptorList * DL)434 parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
435                                   yaml::MappingNode *Descriptor,
436                                   RewriteDescriptorList *DL) {
437   std::string Source;
438   std::string Target;
439   std::string Transform;
440 
441   for (auto &Field : *Descriptor) {
442     yaml::ScalarNode *Key;
443     yaml::ScalarNode *Value;
444     SmallString<32> KeyStorage;
445     SmallString<32> ValueStorage;
446     StringRef KeyValue;
447 
448     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
449     if (!Key) {
450       YS.printError(Field.getKey(), "descriptor key must be a scalar");
451       return false;
452     }
453 
454     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
455     if (!Value) {
456       YS.printError(Field.getValue(), "descriptor value must be a scalar");
457       return false;
458     }
459 
460     KeyValue = Key->getValue(KeyStorage);
461     if (KeyValue.equals("source")) {
462       std::string Error;
463 
464       Source = Value->getValue(ValueStorage);
465       if (!Regex(Source).isValid(Error)) {
466         YS.printError(Field.getKey(), "invalid regex: " + Error);
467         return false;
468       }
469     } else if (KeyValue.equals("target")) {
470       Target = Value->getValue(ValueStorage);
471     } else if (KeyValue.equals("transform")) {
472       Transform = Value->getValue(ValueStorage);
473     } else {
474       YS.printError(Field.getKey(), "unknown key for Global Alias");
475       return false;
476     }
477   }
478 
479   if (Transform.empty() == Target.empty()) {
480     YS.printError(Descriptor,
481                   "exactly one of transform or target must be specified");
482     return false;
483   }
484 
485   if (!Target.empty())
486     DL->push_back(new ExplicitRewriteNamedAliasDescriptor(Source, Target,
487                                                           /*Naked*/false));
488   else
489     DL->push_back(new PatternRewriteNamedAliasDescriptor(Source, Transform));
490 
491   return true;
492 }
493 
494 namespace {
495 class RewriteSymbols : public ModulePass {
496 public:
497   static char ID; // Pass identification, replacement for typeid
498 
499   RewriteSymbols();
500   RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL);
501 
502   bool runOnModule(Module &M) override;
503 
504 private:
505   void loadAndParseMapFiles();
506 
507   SymbolRewriter::RewriteDescriptorList Descriptors;
508 };
509 
510 char RewriteSymbols::ID = 0;
511 
RewriteSymbols()512 RewriteSymbols::RewriteSymbols() : ModulePass(ID) {
513   initializeRewriteSymbolsPass(*PassRegistry::getPassRegistry());
514   loadAndParseMapFiles();
515 }
516 
RewriteSymbols(SymbolRewriter::RewriteDescriptorList & DL)517 RewriteSymbols::RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL)
518     : ModulePass(ID) {
519   Descriptors.splice(Descriptors.begin(), DL);
520 }
521 
runOnModule(Module & M)522 bool RewriteSymbols::runOnModule(Module &M) {
523   bool Changed;
524 
525   Changed = false;
526   for (auto &Descriptor : Descriptors)
527     Changed |= Descriptor.performOnModule(M);
528 
529   return Changed;
530 }
531 
loadAndParseMapFiles()532 void RewriteSymbols::loadAndParseMapFiles() {
533   const std::vector<std::string> MapFiles(RewriteMapFiles);
534   SymbolRewriter::RewriteMapParser parser;
535 
536   for (const auto &MapFile : MapFiles)
537     parser.parse(MapFile, &Descriptors);
538 }
539 }
540 
541 INITIALIZE_PASS(RewriteSymbols, "rewrite-symbols", "Rewrite Symbols", false,
542                 false)
543 
createRewriteSymbolsPass()544 ModulePass *llvm::createRewriteSymbolsPass() { return new RewriteSymbols(); }
545 
546 ModulePass *
createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList & DL)547 llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) {
548   return new RewriteSymbols(DL);
549 }
550