1 //===- SymbolRewriter.cpp - Symbol Rewriter ---------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // SymbolRewriter is an LLVM pass which can rewrite symbols transparently within
11 // existing code.  It is implemented as a compiler pass and is configured via a
12 // YAML configuration file.
13 //
14 // The YAML configuration file format is as follows:
15 //
16 // RewriteMapFile := RewriteDescriptors
17 // RewriteDescriptors := RewriteDescriptor | RewriteDescriptors
18 // RewriteDescriptor := RewriteDescriptorType ':' '{' RewriteDescriptorFields '}'
19 // RewriteDescriptorFields := RewriteDescriptorField | RewriteDescriptorFields
20 // RewriteDescriptorField := FieldIdentifier ':' FieldValue ','
21 // RewriteDescriptorType := Identifier
22 // FieldIdentifier := Identifier
23 // FieldValue := Identifier
24 // Identifier := [0-9a-zA-Z]+
25 //
26 // Currently, the following descriptor types are supported:
27 //
28 // - function:          (function rewriting)
29 //      + Source        (original name of the function)
30 //      + Target        (explicit transformation)
31 //      + Transform     (pattern transformation)
32 //      + Naked         (boolean, whether the function is undecorated)
33 // - global variable:   (external linkage global variable rewriting)
34 //      + Source        (original name of externally visible variable)
35 //      + Target        (explicit transformation)
36 //      + Transform     (pattern transformation)
37 // - global alias:      (global alias rewriting)
38 //      + Source        (original name of the aliased name)
39 //      + Target        (explicit transformation)
40 //      + Transform     (pattern transformation)
41 //
42 // Note that source and exactly one of [Target, Transform] must be provided
43 //
44 // New rewrite descriptors can be created.  Adding a new rewrite descriptor
45 // involves:
46 //
47 //  a) extending the rewrite descriptor kind enumeration
48 //     (<anonymous>::RewriteDescriptor::RewriteDescriptorType)
49 //  b) implementing the new descriptor
50 //     (c.f. <anonymous>::ExplicitRewriteFunctionDescriptor)
51 //  c) extending the rewrite map parser
52 //     (<anonymous>::RewriteMapParser::parseEntry)
53 //
54 //  Specify to rewrite the symbols using the `-rewrite-symbols` option, and
55 //  specify the map file to use for the rewriting via the `-rewrite-map-file`
56 //  option.
57 //
58 //===----------------------------------------------------------------------===//
59 
60 #define DEBUG_TYPE "symbol-rewriter"
61 #include "llvm/CodeGen/Passes.h"
62 #include "llvm/Pass.h"
63 #include "llvm/ADT/SmallString.h"
64 #include "llvm/IR/LegacyPassManager.h"
65 #include "llvm/Support/CommandLine.h"
66 #include "llvm/Support/Debug.h"
67 #include "llvm/Support/MemoryBuffer.h"
68 #include "llvm/Support/Regex.h"
69 #include "llvm/Support/SourceMgr.h"
70 #include "llvm/Support/YAMLParser.h"
71 #include "llvm/Support/raw_ostream.h"
72 #include "llvm/Transforms/Utils/SymbolRewriter.h"
73 
74 using namespace llvm;
75 using namespace SymbolRewriter;
76 
77 static cl::list<std::string> RewriteMapFiles("rewrite-map-file",
78                                              cl::desc("Symbol Rewrite Map"),
79                                              cl::value_desc("filename"));
80 
rewriteComdat(Module & M,GlobalObject * GO,const std::string & Source,const std::string & Target)81 static void rewriteComdat(Module &M, GlobalObject *GO,
82                           const std::string &Source,
83                           const std::string &Target) {
84   if (Comdat *CD = GO->getComdat()) {
85     auto &Comdats = M.getComdatSymbolTable();
86 
87     Comdat *C = M.getOrInsertComdat(Target);
88     C->setSelectionKind(CD->getSelectionKind());
89     GO->setComdat(C);
90 
91     Comdats.erase(Comdats.find(Source));
92   }
93 }
94 
95 namespace {
96 template <RewriteDescriptor::Type DT, typename ValueType,
97           ValueType *(llvm::Module::*Get)(StringRef) const>
98 class ExplicitRewriteDescriptor : public RewriteDescriptor {
99 public:
100   const std::string Source;
101   const std::string Target;
102 
ExplicitRewriteDescriptor(StringRef S,StringRef T,const bool Naked)103   ExplicitRewriteDescriptor(StringRef S, StringRef T, const bool Naked)
104       : RewriteDescriptor(DT), Source(Naked ? StringRef("\01" + S.str()) : S),
105         Target(T) {}
106 
107   bool performOnModule(Module &M) override;
108 
classof(const RewriteDescriptor * RD)109   static bool classof(const RewriteDescriptor *RD) {
110     return RD->getType() == DT;
111   }
112 };
113 
114 template <RewriteDescriptor::Type DT, typename ValueType,
115           ValueType *(llvm::Module::*Get)(StringRef) const>
performOnModule(Module & M)116 bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) {
117   bool Changed = false;
118   if (ValueType *S = (M.*Get)(Source)) {
119     if (GlobalObject *GO = dyn_cast<GlobalObject>(S))
120       rewriteComdat(M, GO, Source, Target);
121 
122     if (Value *T = (M.*Get)(Target))
123       S->setValueName(T->getValueName());
124     else
125       S->setName(Target);
126 
127     Changed = true;
128   }
129   return Changed;
130 }
131 
132 template <RewriteDescriptor::Type DT, typename ValueType,
133           ValueType *(llvm::Module::*Get)(StringRef) const,
134           iterator_range<typename iplist<ValueType>::iterator>
135           (llvm::Module::*Iterator)()>
136 class PatternRewriteDescriptor : public RewriteDescriptor {
137 public:
138   const std::string Pattern;
139   const std::string Transform;
140 
PatternRewriteDescriptor(StringRef P,StringRef T)141   PatternRewriteDescriptor(StringRef P, StringRef T)
142     : RewriteDescriptor(DT), Pattern(P), Transform(T) { }
143 
144   bool performOnModule(Module &M) override;
145 
classof(const RewriteDescriptor * RD)146   static bool classof(const RewriteDescriptor *RD) {
147     return RD->getType() == DT;
148   }
149 };
150 
151 template <RewriteDescriptor::Type DT, typename ValueType,
152           ValueType *(llvm::Module::*Get)(StringRef) const,
153           iterator_range<typename iplist<ValueType>::iterator>
154           (llvm::Module::*Iterator)()>
155 bool PatternRewriteDescriptor<DT, ValueType, Get, Iterator>::
performOnModule(Module & M)156 performOnModule(Module &M) {
157   bool Changed = false;
158   for (auto &C : (M.*Iterator)()) {
159     std::string Error;
160 
161     std::string Name = Regex(Pattern).sub(Transform, C.getName(), &Error);
162     if (!Error.empty())
163       report_fatal_error("unable to transforn " + C.getName() + " in " +
164                          M.getModuleIdentifier() + ": " + Error);
165 
166     if (C.getName() == Name)
167       continue;
168 
169     if (GlobalObject *GO = dyn_cast<GlobalObject>(&C))
170       rewriteComdat(M, GO, C.getName(), Name);
171 
172     if (Value *V = (M.*Get)(Name))
173       C.setValueName(V->getValueName());
174     else
175       C.setName(Name);
176 
177     Changed = true;
178   }
179   return Changed;
180 }
181 
182 /// Represents a rewrite for an explicitly named (function) symbol.  Both the
183 /// source function name and target function name of the transformation are
184 /// explicitly spelt out.
185 typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::Function,
186                                   llvm::Function, &llvm::Module::getFunction>
187     ExplicitRewriteFunctionDescriptor;
188 
189 /// Represents a rewrite for an explicitly named (global variable) symbol.  Both
190 /// the source variable name and target variable name are spelt out.  This
191 /// applies only to module level variables.
192 typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
193                                   llvm::GlobalVariable,
194                                   &llvm::Module::getGlobalVariable>
195     ExplicitRewriteGlobalVariableDescriptor;
196 
197 /// Represents a rewrite for an explicitly named global alias.  Both the source
198 /// and target name are explicitly spelt out.
199 typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::NamedAlias,
200                                   llvm::GlobalAlias,
201                                   &llvm::Module::getNamedAlias>
202     ExplicitRewriteNamedAliasDescriptor;
203 
204 /// Represents a rewrite for a regular expression based pattern for functions.
205 /// A pattern for the function name is provided and a transformation for that
206 /// pattern to determine the target function name create the rewrite rule.
207 typedef PatternRewriteDescriptor<RewriteDescriptor::Type::Function,
208                                  llvm::Function, &llvm::Module::getFunction,
209                                  &llvm::Module::functions>
210     PatternRewriteFunctionDescriptor;
211 
212 /// Represents a rewrite for a global variable based upon a matching pattern.
213 /// Each global variable matching the provided pattern will be transformed as
214 /// described in the transformation pattern for the target.  Applies only to
215 /// module level variables.
216 typedef PatternRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
217                                  llvm::GlobalVariable,
218                                  &llvm::Module::getGlobalVariable,
219                                  &llvm::Module::globals>
220     PatternRewriteGlobalVariableDescriptor;
221 
222 /// PatternRewriteNamedAliasDescriptor - represents a rewrite for global
223 /// aliases which match a given pattern.  The provided transformation will be
224 /// applied to each of the matching names.
225 typedef PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias,
226                                  llvm::GlobalAlias,
227                                  &llvm::Module::getNamedAlias,
228                                  &llvm::Module::aliases>
229     PatternRewriteNamedAliasDescriptor;
230 } // namespace
231 
parse(const std::string & MapFile,RewriteDescriptorList * DL)232 bool RewriteMapParser::parse(const std::string &MapFile,
233                              RewriteDescriptorList *DL) {
234   ErrorOr<std::unique_ptr<MemoryBuffer>> Mapping =
235       MemoryBuffer::getFile(MapFile);
236 
237   if (!Mapping)
238     report_fatal_error("unable to read rewrite map '" + MapFile + "': " +
239                        Mapping.getError().message());
240 
241   if (!parse(*Mapping, DL))
242     report_fatal_error("unable to parse rewrite map '" + MapFile + "'");
243 
244   return true;
245 }
246 
parse(std::unique_ptr<MemoryBuffer> & MapFile,RewriteDescriptorList * DL)247 bool RewriteMapParser::parse(std::unique_ptr<MemoryBuffer> &MapFile,
248                              RewriteDescriptorList *DL) {
249   SourceMgr SM;
250   yaml::Stream YS(MapFile->getBuffer(), SM);
251 
252   for (auto &Document : YS) {
253     yaml::MappingNode *DescriptorList;
254 
255     // ignore empty documents
256     if (isa<yaml::NullNode>(Document.getRoot()))
257       continue;
258 
259     DescriptorList = dyn_cast<yaml::MappingNode>(Document.getRoot());
260     if (!DescriptorList) {
261       YS.printError(Document.getRoot(), "DescriptorList node must be a map");
262       return false;
263     }
264 
265     for (auto &Descriptor : *DescriptorList)
266       if (!parseEntry(YS, Descriptor, DL))
267         return false;
268   }
269 
270   return true;
271 }
272 
parseEntry(yaml::Stream & YS,yaml::KeyValueNode & Entry,RewriteDescriptorList * DL)273 bool RewriteMapParser::parseEntry(yaml::Stream &YS, yaml::KeyValueNode &Entry,
274                                   RewriteDescriptorList *DL) {
275   yaml::ScalarNode *Key;
276   yaml::MappingNode *Value;
277   SmallString<32> KeyStorage;
278   StringRef RewriteType;
279 
280   Key = dyn_cast<yaml::ScalarNode>(Entry.getKey());
281   if (!Key) {
282     YS.printError(Entry.getKey(), "rewrite type must be a scalar");
283     return false;
284   }
285 
286   Value = dyn_cast<yaml::MappingNode>(Entry.getValue());
287   if (!Value) {
288     YS.printError(Entry.getValue(), "rewrite descriptor must be a map");
289     return false;
290   }
291 
292   RewriteType = Key->getValue(KeyStorage);
293   if (RewriteType.equals("function"))
294     return parseRewriteFunctionDescriptor(YS, Key, Value, DL);
295   else if (RewriteType.equals("global variable"))
296     return parseRewriteGlobalVariableDescriptor(YS, Key, Value, DL);
297   else if (RewriteType.equals("global alias"))
298     return parseRewriteGlobalAliasDescriptor(YS, Key, Value, DL);
299 
300   YS.printError(Entry.getKey(), "unknown rewrite type");
301   return false;
302 }
303 
304 bool RewriteMapParser::
parseRewriteFunctionDescriptor(yaml::Stream & YS,yaml::ScalarNode * K,yaml::MappingNode * Descriptor,RewriteDescriptorList * DL)305 parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
306                                yaml::MappingNode *Descriptor,
307                                RewriteDescriptorList *DL) {
308   bool Naked = false;
309   std::string Source;
310   std::string Target;
311   std::string Transform;
312 
313   for (auto &Field : *Descriptor) {
314     yaml::ScalarNode *Key;
315     yaml::ScalarNode *Value;
316     SmallString<32> KeyStorage;
317     SmallString<32> ValueStorage;
318     StringRef KeyValue;
319 
320     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
321     if (!Key) {
322       YS.printError(Field.getKey(), "descriptor key must be a scalar");
323       return false;
324     }
325 
326     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
327     if (!Value) {
328       YS.printError(Field.getValue(), "descriptor value must be a scalar");
329       return false;
330     }
331 
332     KeyValue = Key->getValue(KeyStorage);
333     if (KeyValue.equals("source")) {
334       std::string Error;
335 
336       Source = Value->getValue(ValueStorage);
337       if (!Regex(Source).isValid(Error)) {
338         YS.printError(Field.getKey(), "invalid regex: " + Error);
339         return false;
340       }
341     } else if (KeyValue.equals("target")) {
342       Target = Value->getValue(ValueStorage);
343     } else if (KeyValue.equals("transform")) {
344       Transform = Value->getValue(ValueStorage);
345     } else if (KeyValue.equals("naked")) {
346       std::string Undecorated;
347 
348       Undecorated = Value->getValue(ValueStorage);
349       Naked = StringRef(Undecorated).lower() == "true" || Undecorated == "1";
350     } else {
351       YS.printError(Field.getKey(), "unknown key for function");
352       return false;
353     }
354   }
355 
356   if (Transform.empty() == Target.empty()) {
357     YS.printError(Descriptor,
358                   "exactly one of transform or target must be specified");
359     return false;
360   }
361 
362   // TODO see if there is a more elegant solution to selecting the rewrite
363   // descriptor type
364   if (!Target.empty())
365     DL->push_back(new ExplicitRewriteFunctionDescriptor(Source, Target, Naked));
366   else
367     DL->push_back(new PatternRewriteFunctionDescriptor(Source, Transform));
368 
369   return true;
370 }
371 
372 bool RewriteMapParser::
parseRewriteGlobalVariableDescriptor(yaml::Stream & YS,yaml::ScalarNode * K,yaml::MappingNode * Descriptor,RewriteDescriptorList * DL)373 parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
374                                      yaml::MappingNode *Descriptor,
375                                      RewriteDescriptorList *DL) {
376   std::string Source;
377   std::string Target;
378   std::string Transform;
379 
380   for (auto &Field : *Descriptor) {
381     yaml::ScalarNode *Key;
382     yaml::ScalarNode *Value;
383     SmallString<32> KeyStorage;
384     SmallString<32> ValueStorage;
385     StringRef KeyValue;
386 
387     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
388     if (!Key) {
389       YS.printError(Field.getKey(), "descriptor Key must be a scalar");
390       return false;
391     }
392 
393     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
394     if (!Value) {
395       YS.printError(Field.getValue(), "descriptor value must be a scalar");
396       return false;
397     }
398 
399     KeyValue = Key->getValue(KeyStorage);
400     if (KeyValue.equals("source")) {
401       std::string Error;
402 
403       Source = Value->getValue(ValueStorage);
404       if (!Regex(Source).isValid(Error)) {
405         YS.printError(Field.getKey(), "invalid regex: " + Error);
406         return false;
407       }
408     } else if (KeyValue.equals("target")) {
409       Target = Value->getValue(ValueStorage);
410     } else if (KeyValue.equals("transform")) {
411       Transform = Value->getValue(ValueStorage);
412     } else {
413       YS.printError(Field.getKey(), "unknown Key for Global Variable");
414       return false;
415     }
416   }
417 
418   if (Transform.empty() == Target.empty()) {
419     YS.printError(Descriptor,
420                   "exactly one of transform or target must be specified");
421     return false;
422   }
423 
424   if (!Target.empty())
425     DL->push_back(new ExplicitRewriteGlobalVariableDescriptor(Source, Target,
426                                                               /*Naked*/false));
427   else
428     DL->push_back(new PatternRewriteGlobalVariableDescriptor(Source,
429                                                              Transform));
430 
431   return true;
432 }
433 
434 bool RewriteMapParser::
parseRewriteGlobalAliasDescriptor(yaml::Stream & YS,yaml::ScalarNode * K,yaml::MappingNode * Descriptor,RewriteDescriptorList * DL)435 parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
436                                   yaml::MappingNode *Descriptor,
437                                   RewriteDescriptorList *DL) {
438   std::string Source;
439   std::string Target;
440   std::string Transform;
441 
442   for (auto &Field : *Descriptor) {
443     yaml::ScalarNode *Key;
444     yaml::ScalarNode *Value;
445     SmallString<32> KeyStorage;
446     SmallString<32> ValueStorage;
447     StringRef KeyValue;
448 
449     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
450     if (!Key) {
451       YS.printError(Field.getKey(), "descriptor key must be a scalar");
452       return false;
453     }
454 
455     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
456     if (!Value) {
457       YS.printError(Field.getValue(), "descriptor value must be a scalar");
458       return false;
459     }
460 
461     KeyValue = Key->getValue(KeyStorage);
462     if (KeyValue.equals("source")) {
463       std::string Error;
464 
465       Source = Value->getValue(ValueStorage);
466       if (!Regex(Source).isValid(Error)) {
467         YS.printError(Field.getKey(), "invalid regex: " + Error);
468         return false;
469       }
470     } else if (KeyValue.equals("target")) {
471       Target = Value->getValue(ValueStorage);
472     } else if (KeyValue.equals("transform")) {
473       Transform = Value->getValue(ValueStorage);
474     } else {
475       YS.printError(Field.getKey(), "unknown key for Global Alias");
476       return false;
477     }
478   }
479 
480   if (Transform.empty() == Target.empty()) {
481     YS.printError(Descriptor,
482                   "exactly one of transform or target must be specified");
483     return false;
484   }
485 
486   if (!Target.empty())
487     DL->push_back(new ExplicitRewriteNamedAliasDescriptor(Source, Target,
488                                                           /*Naked*/false));
489   else
490     DL->push_back(new PatternRewriteNamedAliasDescriptor(Source, Transform));
491 
492   return true;
493 }
494 
495 namespace {
496 class RewriteSymbols : public ModulePass {
497 public:
498   static char ID; // Pass identification, replacement for typeid
499 
500   RewriteSymbols();
501   RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL);
502 
503   bool runOnModule(Module &M) override;
504 
505 private:
506   void loadAndParseMapFiles();
507 
508   SymbolRewriter::RewriteDescriptorList Descriptors;
509 };
510 
511 char RewriteSymbols::ID = 0;
512 
RewriteSymbols()513 RewriteSymbols::RewriteSymbols() : ModulePass(ID) {
514   initializeRewriteSymbolsPass(*PassRegistry::getPassRegistry());
515   loadAndParseMapFiles();
516 }
517 
RewriteSymbols(SymbolRewriter::RewriteDescriptorList & DL)518 RewriteSymbols::RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL)
519     : ModulePass(ID) {
520   Descriptors.splice(Descriptors.begin(), DL);
521 }
522 
runOnModule(Module & M)523 bool RewriteSymbols::runOnModule(Module &M) {
524   bool Changed;
525 
526   Changed = false;
527   for (auto &Descriptor : Descriptors)
528     Changed |= Descriptor.performOnModule(M);
529 
530   return Changed;
531 }
532 
loadAndParseMapFiles()533 void RewriteSymbols::loadAndParseMapFiles() {
534   const std::vector<std::string> MapFiles(RewriteMapFiles);
535   SymbolRewriter::RewriteMapParser parser;
536 
537   for (const auto &MapFile : MapFiles)
538     parser.parse(MapFile, &Descriptors);
539 }
540 }
541 
542 INITIALIZE_PASS(RewriteSymbols, "rewrite-symbols", "Rewrite Symbols", false,
543                 false)
544 
createRewriteSymbolsPass()545 ModulePass *llvm::createRewriteSymbolsPass() { return new RewriteSymbols(); }
546 
547 ModulePass *
createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList & DL)548 llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) {
549   return new RewriteSymbols(DL);
550 }
551