1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Instrumentation-based profile-guided optimization
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "CodeGenPGO.h"
15 #include "CodeGenFunction.h"
16 #include "CoverageMappingGen.h"
17 #include "clang/AST/RecursiveASTVisitor.h"
18 #include "clang/AST/StmtVisitor.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/MDBuilder.h"
21 #include "llvm/ProfileData/InstrProfReader.h"
22 #include "llvm/Support/Endian.h"
23 #include "llvm/Support/FileSystem.h"
24 #include "llvm/Support/MD5.h"
25
26 using namespace clang;
27 using namespace CodeGen;
28
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)29 void CodeGenPGO::setFuncName(StringRef Name,
30 llvm::GlobalValue::LinkageTypes Linkage) {
31 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
32 FuncName = llvm::getPGOFuncName(
33 Name, Linkage, CGM.getCodeGenOpts().MainFileName,
34 PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
35
36 // If we're generating a profile, create a variable for the name.
37 if (CGM.getCodeGenOpts().ProfileInstrGenerate)
38 FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
39 }
40
setFuncName(llvm::Function * Fn)41 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
42 setFuncName(Fn->getName(), Fn->getLinkage());
43 }
44
45 namespace {
46 /// \brief Stable hasher for PGO region counters.
47 ///
48 /// PGOHash produces a stable hash of a given function's control flow.
49 ///
50 /// Changing the output of this hash will invalidate all previously generated
51 /// profiles -- i.e., don't do it.
52 ///
53 /// \note When this hash does eventually change (years?), we still need to
54 /// support old hashes. We'll need to pull in the version number from the
55 /// profile data format and use the matching hash function.
56 class PGOHash {
57 uint64_t Working;
58 unsigned Count;
59 llvm::MD5 MD5;
60
61 static const int NumBitsPerType = 6;
62 static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
63 static const unsigned TooBig = 1u << NumBitsPerType;
64
65 public:
66 /// \brief Hash values for AST nodes.
67 ///
68 /// Distinct values for AST nodes that have region counters attached.
69 ///
70 /// These values must be stable. All new members must be added at the end,
71 /// and no members should be removed. Changing the enumeration value for an
72 /// AST node will affect the hash of every function that contains that node.
73 enum HashType : unsigned char {
74 None = 0,
75 LabelStmt = 1,
76 WhileStmt,
77 DoStmt,
78 ForStmt,
79 CXXForRangeStmt,
80 ObjCForCollectionStmt,
81 SwitchStmt,
82 CaseStmt,
83 DefaultStmt,
84 IfStmt,
85 CXXTryStmt,
86 CXXCatchStmt,
87 ConditionalOperator,
88 BinaryOperatorLAnd,
89 BinaryOperatorLOr,
90 BinaryConditionalOperator,
91
92 // Keep this last. It's for the static assert that follows.
93 LastHashType
94 };
95 static_assert(LastHashType <= TooBig, "Too many types in HashType");
96
97 // TODO: When this format changes, take in a version number here, and use the
98 // old hash calculation for file formats that used the old hash.
PGOHash()99 PGOHash() : Working(0), Count(0) {}
100 void combine(HashType Type);
101 uint64_t finalize();
102 };
103 const int PGOHash::NumBitsPerType;
104 const unsigned PGOHash::NumTypesPerWord;
105 const unsigned PGOHash::TooBig;
106
107 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
108 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
109 /// The next counter value to assign.
110 unsigned NextCounter;
111 /// The function hash.
112 PGOHash Hash;
113 /// The map of statements to counters.
114 llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
115
MapRegionCounters__anone18cd8db0111::MapRegionCounters116 MapRegionCounters(llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
117 : NextCounter(0), CounterMap(CounterMap) {}
118
119 // Blocks and lambdas are handled as separate functions, so we need not
120 // traverse them in the parent context.
TraverseBlockExpr__anone18cd8db0111::MapRegionCounters121 bool TraverseBlockExpr(BlockExpr *BE) { return true; }
TraverseLambdaBody__anone18cd8db0111::MapRegionCounters122 bool TraverseLambdaBody(LambdaExpr *LE) { return true; }
TraverseCapturedStmt__anone18cd8db0111::MapRegionCounters123 bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
124
VisitDecl__anone18cd8db0111::MapRegionCounters125 bool VisitDecl(const Decl *D) {
126 switch (D->getKind()) {
127 default:
128 break;
129 case Decl::Function:
130 case Decl::CXXMethod:
131 case Decl::CXXConstructor:
132 case Decl::CXXDestructor:
133 case Decl::CXXConversion:
134 case Decl::ObjCMethod:
135 case Decl::Block:
136 case Decl::Captured:
137 CounterMap[D->getBody()] = NextCounter++;
138 break;
139 }
140 return true;
141 }
142
VisitStmt__anone18cd8db0111::MapRegionCounters143 bool VisitStmt(const Stmt *S) {
144 auto Type = getHashType(S);
145 if (Type == PGOHash::None)
146 return true;
147
148 CounterMap[S] = NextCounter++;
149 Hash.combine(Type);
150 return true;
151 }
getHashType__anone18cd8db0111::MapRegionCounters152 PGOHash::HashType getHashType(const Stmt *S) {
153 switch (S->getStmtClass()) {
154 default:
155 break;
156 case Stmt::LabelStmtClass:
157 return PGOHash::LabelStmt;
158 case Stmt::WhileStmtClass:
159 return PGOHash::WhileStmt;
160 case Stmt::DoStmtClass:
161 return PGOHash::DoStmt;
162 case Stmt::ForStmtClass:
163 return PGOHash::ForStmt;
164 case Stmt::CXXForRangeStmtClass:
165 return PGOHash::CXXForRangeStmt;
166 case Stmt::ObjCForCollectionStmtClass:
167 return PGOHash::ObjCForCollectionStmt;
168 case Stmt::SwitchStmtClass:
169 return PGOHash::SwitchStmt;
170 case Stmt::CaseStmtClass:
171 return PGOHash::CaseStmt;
172 case Stmt::DefaultStmtClass:
173 return PGOHash::DefaultStmt;
174 case Stmt::IfStmtClass:
175 return PGOHash::IfStmt;
176 case Stmt::CXXTryStmtClass:
177 return PGOHash::CXXTryStmt;
178 case Stmt::CXXCatchStmtClass:
179 return PGOHash::CXXCatchStmt;
180 case Stmt::ConditionalOperatorClass:
181 return PGOHash::ConditionalOperator;
182 case Stmt::BinaryConditionalOperatorClass:
183 return PGOHash::BinaryConditionalOperator;
184 case Stmt::BinaryOperatorClass: {
185 const BinaryOperator *BO = cast<BinaryOperator>(S);
186 if (BO->getOpcode() == BO_LAnd)
187 return PGOHash::BinaryOperatorLAnd;
188 if (BO->getOpcode() == BO_LOr)
189 return PGOHash::BinaryOperatorLOr;
190 break;
191 }
192 }
193 return PGOHash::None;
194 }
195 };
196
197 /// A StmtVisitor that propagates the raw counts through the AST and
198 /// records the count at statements where the value may change.
199 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
200 /// PGO state.
201 CodeGenPGO &PGO;
202
203 /// A flag that is set when the current count should be recorded on the
204 /// next statement, such as at the exit of a loop.
205 bool RecordNextStmtCount;
206
207 /// The count at the current location in the traversal.
208 uint64_t CurrentCount;
209
210 /// The map of statements to count values.
211 llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
212
213 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
214 struct BreakContinue {
215 uint64_t BreakCount;
216 uint64_t ContinueCount;
BreakContinue__anone18cd8db0111::ComputeRegionCounts::BreakContinue217 BreakContinue() : BreakCount(0), ContinueCount(0) {}
218 };
219 SmallVector<BreakContinue, 8> BreakContinueStack;
220
ComputeRegionCounts__anone18cd8db0111::ComputeRegionCounts221 ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
222 CodeGenPGO &PGO)
223 : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
224
RecordStmtCount__anone18cd8db0111::ComputeRegionCounts225 void RecordStmtCount(const Stmt *S) {
226 if (RecordNextStmtCount) {
227 CountMap[S] = CurrentCount;
228 RecordNextStmtCount = false;
229 }
230 }
231
232 /// Set and return the current count.
setCount__anone18cd8db0111::ComputeRegionCounts233 uint64_t setCount(uint64_t Count) {
234 CurrentCount = Count;
235 return Count;
236 }
237
VisitStmt__anone18cd8db0111::ComputeRegionCounts238 void VisitStmt(const Stmt *S) {
239 RecordStmtCount(S);
240 for (const Stmt *Child : S->children())
241 if (Child)
242 this->Visit(Child);
243 }
244
VisitFunctionDecl__anone18cd8db0111::ComputeRegionCounts245 void VisitFunctionDecl(const FunctionDecl *D) {
246 // Counter tracks entry to the function body.
247 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
248 CountMap[D->getBody()] = BodyCount;
249 Visit(D->getBody());
250 }
251
252 // Skip lambda expressions. We visit these as FunctionDecls when we're
253 // generating them and aren't interested in the body when generating a
254 // parent context.
VisitLambdaExpr__anone18cd8db0111::ComputeRegionCounts255 void VisitLambdaExpr(const LambdaExpr *LE) {}
256
VisitCapturedDecl__anone18cd8db0111::ComputeRegionCounts257 void VisitCapturedDecl(const CapturedDecl *D) {
258 // Counter tracks entry to the capture body.
259 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
260 CountMap[D->getBody()] = BodyCount;
261 Visit(D->getBody());
262 }
263
VisitObjCMethodDecl__anone18cd8db0111::ComputeRegionCounts264 void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
265 // Counter tracks entry to the method body.
266 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
267 CountMap[D->getBody()] = BodyCount;
268 Visit(D->getBody());
269 }
270
VisitBlockDecl__anone18cd8db0111::ComputeRegionCounts271 void VisitBlockDecl(const BlockDecl *D) {
272 // Counter tracks entry to the block body.
273 uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
274 CountMap[D->getBody()] = BodyCount;
275 Visit(D->getBody());
276 }
277
VisitReturnStmt__anone18cd8db0111::ComputeRegionCounts278 void VisitReturnStmt(const ReturnStmt *S) {
279 RecordStmtCount(S);
280 if (S->getRetValue())
281 Visit(S->getRetValue());
282 CurrentCount = 0;
283 RecordNextStmtCount = true;
284 }
285
VisitCXXThrowExpr__anone18cd8db0111::ComputeRegionCounts286 void VisitCXXThrowExpr(const CXXThrowExpr *E) {
287 RecordStmtCount(E);
288 if (E->getSubExpr())
289 Visit(E->getSubExpr());
290 CurrentCount = 0;
291 RecordNextStmtCount = true;
292 }
293
VisitGotoStmt__anone18cd8db0111::ComputeRegionCounts294 void VisitGotoStmt(const GotoStmt *S) {
295 RecordStmtCount(S);
296 CurrentCount = 0;
297 RecordNextStmtCount = true;
298 }
299
VisitLabelStmt__anone18cd8db0111::ComputeRegionCounts300 void VisitLabelStmt(const LabelStmt *S) {
301 RecordNextStmtCount = false;
302 // Counter tracks the block following the label.
303 uint64_t BlockCount = setCount(PGO.getRegionCount(S));
304 CountMap[S] = BlockCount;
305 Visit(S->getSubStmt());
306 }
307
VisitBreakStmt__anone18cd8db0111::ComputeRegionCounts308 void VisitBreakStmt(const BreakStmt *S) {
309 RecordStmtCount(S);
310 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
311 BreakContinueStack.back().BreakCount += CurrentCount;
312 CurrentCount = 0;
313 RecordNextStmtCount = true;
314 }
315
VisitContinueStmt__anone18cd8db0111::ComputeRegionCounts316 void VisitContinueStmt(const ContinueStmt *S) {
317 RecordStmtCount(S);
318 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
319 BreakContinueStack.back().ContinueCount += CurrentCount;
320 CurrentCount = 0;
321 RecordNextStmtCount = true;
322 }
323
VisitWhileStmt__anone18cd8db0111::ComputeRegionCounts324 void VisitWhileStmt(const WhileStmt *S) {
325 RecordStmtCount(S);
326 uint64_t ParentCount = CurrentCount;
327
328 BreakContinueStack.push_back(BreakContinue());
329 // Visit the body region first so the break/continue adjustments can be
330 // included when visiting the condition.
331 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
332 CountMap[S->getBody()] = CurrentCount;
333 Visit(S->getBody());
334 uint64_t BackedgeCount = CurrentCount;
335
336 // ...then go back and propagate counts through the condition. The count
337 // at the start of the condition is the sum of the incoming edges,
338 // the backedge from the end of the loop body, and the edges from
339 // continue statements.
340 BreakContinue BC = BreakContinueStack.pop_back_val();
341 uint64_t CondCount =
342 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
343 CountMap[S->getCond()] = CondCount;
344 Visit(S->getCond());
345 setCount(BC.BreakCount + CondCount - BodyCount);
346 RecordNextStmtCount = true;
347 }
348
VisitDoStmt__anone18cd8db0111::ComputeRegionCounts349 void VisitDoStmt(const DoStmt *S) {
350 RecordStmtCount(S);
351 uint64_t LoopCount = PGO.getRegionCount(S);
352
353 BreakContinueStack.push_back(BreakContinue());
354 // The count doesn't include the fallthrough from the parent scope. Add it.
355 uint64_t BodyCount = setCount(LoopCount + CurrentCount);
356 CountMap[S->getBody()] = BodyCount;
357 Visit(S->getBody());
358 uint64_t BackedgeCount = CurrentCount;
359
360 BreakContinue BC = BreakContinueStack.pop_back_val();
361 // The count at the start of the condition is equal to the count at the
362 // end of the body, plus any continues.
363 uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
364 CountMap[S->getCond()] = CondCount;
365 Visit(S->getCond());
366 setCount(BC.BreakCount + CondCount - LoopCount);
367 RecordNextStmtCount = true;
368 }
369
VisitForStmt__anone18cd8db0111::ComputeRegionCounts370 void VisitForStmt(const ForStmt *S) {
371 RecordStmtCount(S);
372 if (S->getInit())
373 Visit(S->getInit());
374
375 uint64_t ParentCount = CurrentCount;
376
377 BreakContinueStack.push_back(BreakContinue());
378 // Visit the body region first. (This is basically the same as a while
379 // loop; see further comments in VisitWhileStmt.)
380 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
381 CountMap[S->getBody()] = BodyCount;
382 Visit(S->getBody());
383 uint64_t BackedgeCount = CurrentCount;
384 BreakContinue BC = BreakContinueStack.pop_back_val();
385
386 // The increment is essentially part of the body but it needs to include
387 // the count for all the continue statements.
388 if (S->getInc()) {
389 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
390 CountMap[S->getInc()] = IncCount;
391 Visit(S->getInc());
392 }
393
394 // ...then go back and propagate counts through the condition.
395 uint64_t CondCount =
396 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
397 if (S->getCond()) {
398 CountMap[S->getCond()] = CondCount;
399 Visit(S->getCond());
400 }
401 setCount(BC.BreakCount + CondCount - BodyCount);
402 RecordNextStmtCount = true;
403 }
404
VisitCXXForRangeStmt__anone18cd8db0111::ComputeRegionCounts405 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
406 RecordStmtCount(S);
407 Visit(S->getLoopVarStmt());
408 Visit(S->getRangeStmt());
409 Visit(S->getBeginEndStmt());
410
411 uint64_t ParentCount = CurrentCount;
412 BreakContinueStack.push_back(BreakContinue());
413 // Visit the body region first. (This is basically the same as a while
414 // loop; see further comments in VisitWhileStmt.)
415 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
416 CountMap[S->getBody()] = BodyCount;
417 Visit(S->getBody());
418 uint64_t BackedgeCount = CurrentCount;
419 BreakContinue BC = BreakContinueStack.pop_back_val();
420
421 // The increment is essentially part of the body but it needs to include
422 // the count for all the continue statements.
423 uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
424 CountMap[S->getInc()] = IncCount;
425 Visit(S->getInc());
426
427 // ...then go back and propagate counts through the condition.
428 uint64_t CondCount =
429 setCount(ParentCount + BackedgeCount + BC.ContinueCount);
430 CountMap[S->getCond()] = CondCount;
431 Visit(S->getCond());
432 setCount(BC.BreakCount + CondCount - BodyCount);
433 RecordNextStmtCount = true;
434 }
435
VisitObjCForCollectionStmt__anone18cd8db0111::ComputeRegionCounts436 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
437 RecordStmtCount(S);
438 Visit(S->getElement());
439 uint64_t ParentCount = CurrentCount;
440 BreakContinueStack.push_back(BreakContinue());
441 // Counter tracks the body of the loop.
442 uint64_t BodyCount = setCount(PGO.getRegionCount(S));
443 CountMap[S->getBody()] = BodyCount;
444 Visit(S->getBody());
445 uint64_t BackedgeCount = CurrentCount;
446 BreakContinue BC = BreakContinueStack.pop_back_val();
447
448 setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
449 BodyCount);
450 RecordNextStmtCount = true;
451 }
452
VisitSwitchStmt__anone18cd8db0111::ComputeRegionCounts453 void VisitSwitchStmt(const SwitchStmt *S) {
454 RecordStmtCount(S);
455 Visit(S->getCond());
456 CurrentCount = 0;
457 BreakContinueStack.push_back(BreakContinue());
458 Visit(S->getBody());
459 // If the switch is inside a loop, add the continue counts.
460 BreakContinue BC = BreakContinueStack.pop_back_val();
461 if (!BreakContinueStack.empty())
462 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
463 // Counter tracks the exit block of the switch.
464 setCount(PGO.getRegionCount(S));
465 RecordNextStmtCount = true;
466 }
467
VisitSwitchCase__anone18cd8db0111::ComputeRegionCounts468 void VisitSwitchCase(const SwitchCase *S) {
469 RecordNextStmtCount = false;
470 // Counter for this particular case. This counts only jumps from the
471 // switch header and does not include fallthrough from the case before
472 // this one.
473 uint64_t CaseCount = PGO.getRegionCount(S);
474 setCount(CurrentCount + CaseCount);
475 // We need the count without fallthrough in the mapping, so it's more useful
476 // for branch probabilities.
477 CountMap[S] = CaseCount;
478 RecordNextStmtCount = true;
479 Visit(S->getSubStmt());
480 }
481
VisitIfStmt__anone18cd8db0111::ComputeRegionCounts482 void VisitIfStmt(const IfStmt *S) {
483 RecordStmtCount(S);
484 uint64_t ParentCount = CurrentCount;
485 Visit(S->getCond());
486
487 // Counter tracks the "then" part of an if statement. The count for
488 // the "else" part, if it exists, will be calculated from this counter.
489 uint64_t ThenCount = setCount(PGO.getRegionCount(S));
490 CountMap[S->getThen()] = ThenCount;
491 Visit(S->getThen());
492 uint64_t OutCount = CurrentCount;
493
494 uint64_t ElseCount = ParentCount - ThenCount;
495 if (S->getElse()) {
496 setCount(ElseCount);
497 CountMap[S->getElse()] = ElseCount;
498 Visit(S->getElse());
499 OutCount += CurrentCount;
500 } else
501 OutCount += ElseCount;
502 setCount(OutCount);
503 RecordNextStmtCount = true;
504 }
505
VisitCXXTryStmt__anone18cd8db0111::ComputeRegionCounts506 void VisitCXXTryStmt(const CXXTryStmt *S) {
507 RecordStmtCount(S);
508 Visit(S->getTryBlock());
509 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
510 Visit(S->getHandler(I));
511 // Counter tracks the continuation block of the try statement.
512 setCount(PGO.getRegionCount(S));
513 RecordNextStmtCount = true;
514 }
515
VisitCXXCatchStmt__anone18cd8db0111::ComputeRegionCounts516 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
517 RecordNextStmtCount = false;
518 // Counter tracks the catch statement's handler block.
519 uint64_t CatchCount = setCount(PGO.getRegionCount(S));
520 CountMap[S] = CatchCount;
521 Visit(S->getHandlerBlock());
522 }
523
VisitAbstractConditionalOperator__anone18cd8db0111::ComputeRegionCounts524 void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
525 RecordStmtCount(E);
526 uint64_t ParentCount = CurrentCount;
527 Visit(E->getCond());
528
529 // Counter tracks the "true" part of a conditional operator. The
530 // count in the "false" part will be calculated from this counter.
531 uint64_t TrueCount = setCount(PGO.getRegionCount(E));
532 CountMap[E->getTrueExpr()] = TrueCount;
533 Visit(E->getTrueExpr());
534 uint64_t OutCount = CurrentCount;
535
536 uint64_t FalseCount = setCount(ParentCount - TrueCount);
537 CountMap[E->getFalseExpr()] = FalseCount;
538 Visit(E->getFalseExpr());
539 OutCount += CurrentCount;
540
541 setCount(OutCount);
542 RecordNextStmtCount = true;
543 }
544
VisitBinLAnd__anone18cd8db0111::ComputeRegionCounts545 void VisitBinLAnd(const BinaryOperator *E) {
546 RecordStmtCount(E);
547 uint64_t ParentCount = CurrentCount;
548 Visit(E->getLHS());
549 // Counter tracks the right hand side of a logical and operator.
550 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
551 CountMap[E->getRHS()] = RHSCount;
552 Visit(E->getRHS());
553 setCount(ParentCount + RHSCount - CurrentCount);
554 RecordNextStmtCount = true;
555 }
556
VisitBinLOr__anone18cd8db0111::ComputeRegionCounts557 void VisitBinLOr(const BinaryOperator *E) {
558 RecordStmtCount(E);
559 uint64_t ParentCount = CurrentCount;
560 Visit(E->getLHS());
561 // Counter tracks the right hand side of a logical or operator.
562 uint64_t RHSCount = setCount(PGO.getRegionCount(E));
563 CountMap[E->getRHS()] = RHSCount;
564 Visit(E->getRHS());
565 setCount(ParentCount + RHSCount - CurrentCount);
566 RecordNextStmtCount = true;
567 }
568 };
569 } // end anonymous namespace
570
combine(HashType Type)571 void PGOHash::combine(HashType Type) {
572 // Check that we never combine 0 and only have six bits.
573 assert(Type && "Hash is invalid: unexpected type 0");
574 assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
575
576 // Pass through MD5 if enough work has built up.
577 if (Count && Count % NumTypesPerWord == 0) {
578 using namespace llvm::support;
579 uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
580 MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
581 Working = 0;
582 }
583
584 // Accumulate the current type.
585 ++Count;
586 Working = Working << NumBitsPerType | Type;
587 }
588
finalize()589 uint64_t PGOHash::finalize() {
590 // Use Working as the hash directly if we never used MD5.
591 if (Count <= NumTypesPerWord)
592 // No need to byte swap here, since none of the math was endian-dependent.
593 // This number will be byte-swapped as required on endianness transitions,
594 // so we will see the same value on the other side.
595 return Working;
596
597 // Check for remaining work in Working.
598 if (Working)
599 MD5.update(Working);
600
601 // Finalize the MD5 and return the hash.
602 llvm::MD5::MD5Result Result;
603 MD5.final(Result);
604 using namespace llvm::support;
605 return endian::read<uint64_t, little, unaligned>(Result);
606 }
607
assignRegionCounters(GlobalDecl GD,llvm::Function * Fn)608 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
609 const Decl *D = GD.getDecl();
610 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
611 llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
612 if (!InstrumentRegions && !PGOReader)
613 return;
614 if (D->isImplicit())
615 return;
616 // Constructors and destructors may be represented by several functions in IR.
617 // If so, instrument only base variant, others are implemented by delegation
618 // to the base one, it would be counted twice otherwise.
619 if (CGM.getTarget().getCXXABI().hasConstructorVariants() &&
620 ((isa<CXXConstructorDecl>(GD.getDecl()) &&
621 GD.getCtorType() != Ctor_Base) ||
622 (isa<CXXDestructorDecl>(GD.getDecl()) &&
623 GD.getDtorType() != Dtor_Base))) {
624 return;
625 }
626 CGM.ClearUnusedCoverageMapping(D);
627 setFuncName(Fn);
628
629 mapRegionCounters(D);
630 if (CGM.getCodeGenOpts().CoverageMapping)
631 emitCounterRegionMapping(D);
632 if (PGOReader) {
633 SourceManager &SM = CGM.getContext().getSourceManager();
634 loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
635 computeRegionCounts(D);
636 applyFunctionAttributes(PGOReader, Fn);
637 }
638 }
639
mapRegionCounters(const Decl * D)640 void CodeGenPGO::mapRegionCounters(const Decl *D) {
641 RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
642 MapRegionCounters Walker(*RegionCounterMap);
643 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
644 Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
645 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
646 Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
647 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
648 Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
649 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
650 Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
651 assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
652 NumRegionCounters = Walker.NextCounter;
653 FunctionHash = Walker.Hash.finalize();
654 }
655
emitCounterRegionMapping(const Decl * D)656 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
657 if (SkipCoverageMapping)
658 return;
659 // Don't map the functions inside the system headers
660 auto Loc = D->getBody()->getLocStart();
661 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
662 return;
663
664 std::string CoverageMapping;
665 llvm::raw_string_ostream OS(CoverageMapping);
666 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
667 CGM.getContext().getSourceManager(),
668 CGM.getLangOpts(), RegionCounterMap.get());
669 MappingGen.emitCounterMapping(D, OS);
670 OS.flush();
671
672 if (CoverageMapping.empty())
673 return;
674
675 CGM.getCoverageMapping()->addFunctionMappingRecord(
676 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
677 }
678
679 void
emitEmptyCounterMapping(const Decl * D,StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)680 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
681 llvm::GlobalValue::LinkageTypes Linkage) {
682 if (SkipCoverageMapping)
683 return;
684 // Don't map the functions inside the system headers
685 auto Loc = D->getBody()->getLocStart();
686 if (CGM.getContext().getSourceManager().isInSystemHeader(Loc))
687 return;
688
689 std::string CoverageMapping;
690 llvm::raw_string_ostream OS(CoverageMapping);
691 CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
692 CGM.getContext().getSourceManager(),
693 CGM.getLangOpts());
694 MappingGen.emitEmptyMapping(D, OS);
695 OS.flush();
696
697 if (CoverageMapping.empty())
698 return;
699
700 setFuncName(Name, Linkage);
701 CGM.getCoverageMapping()->addFunctionMappingRecord(
702 FuncNameVar, FuncName, FunctionHash, CoverageMapping);
703 }
704
computeRegionCounts(const Decl * D)705 void CodeGenPGO::computeRegionCounts(const Decl *D) {
706 StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
707 ComputeRegionCounts Walker(*StmtCountMap, *this);
708 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
709 Walker.VisitFunctionDecl(FD);
710 else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
711 Walker.VisitObjCMethodDecl(MD);
712 else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
713 Walker.VisitBlockDecl(BD);
714 else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
715 Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
716 }
717
718 void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)719 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
720 llvm::Function *Fn) {
721 if (!haveRegionCounts())
722 return;
723
724 uint64_t MaxFunctionCount = PGOReader->getMaximumFunctionCount();
725 uint64_t FunctionCount = getRegionCount(nullptr);
726 if (FunctionCount >= (uint64_t)(0.3 * (double)MaxFunctionCount))
727 // Turn on InlineHint attribute for hot functions.
728 // FIXME: 30% is from preliminary tuning on SPEC, it may not be optimal.
729 Fn->addFnAttr(llvm::Attribute::InlineHint);
730 else if (FunctionCount <= (uint64_t)(0.01 * (double)MaxFunctionCount))
731 // Turn on Cold attribute for cold functions.
732 // FIXME: 1% is from preliminary tuning on SPEC, it may not be optimal.
733 Fn->addFnAttr(llvm::Attribute::Cold);
734
735 Fn->setEntryCount(FunctionCount);
736 }
737
emitCounterIncrement(CGBuilderTy & Builder,const Stmt * S)738 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S) {
739 if (!CGM.getCodeGenOpts().ProfileInstrGenerate || !RegionCounterMap)
740 return;
741 if (!Builder.GetInsertBlock())
742 return;
743
744 unsigned Counter = (*RegionCounterMap)[S];
745 auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
746 Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
747 {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
748 Builder.getInt64(FunctionHash),
749 Builder.getInt32(NumRegionCounters),
750 Builder.getInt32(Counter)});
751 }
752
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)753 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
754 bool IsInMainFile) {
755 CGM.getPGOStats().addVisited(IsInMainFile);
756 RegionCounts.clear();
757 if (std::error_code EC =
758 PGOReader->getFunctionCounts(FuncName, FunctionHash, RegionCounts)) {
759 if (EC == llvm::instrprof_error::unknown_function)
760 CGM.getPGOStats().addMissing(IsInMainFile);
761 else if (EC == llvm::instrprof_error::hash_mismatch)
762 CGM.getPGOStats().addMismatched(IsInMainFile);
763 else if (EC == llvm::instrprof_error::malformed)
764 // TODO: Consider a more specific warning for this case.
765 CGM.getPGOStats().addMismatched(IsInMainFile);
766 RegionCounts.clear();
767 }
768 }
769
/// \brief Calculate what to divide by to scale weights.
///
/// Given the maximum weight, calculate a divisor that will scale all the
/// weights to strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  // Already small enough: leave weights unscaled.
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
777
/// \brief Scale an individual branch weight (and add 1).
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
/// According to Laplace's Rule of Succession, it is better to compute the
/// weight based on the count plus 1, so universally add 1 to the value.
///
/// \pre \c Scale was calculated by \a calculateWeightScale() with a weight no
/// less than \c Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Result = Weight / Scale + 1;
  assert(Result <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Result);
}
793
createProfileWeights(uint64_t TrueCount,uint64_t FalseCount)794 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
795 uint64_t FalseCount) {
796 // Check for empty weights.
797 if (!TrueCount && !FalseCount)
798 return nullptr;
799
800 // Calculate how to scale down to 32-bits.
801 uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
802
803 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
804 return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
805 scaleBranchWeight(FalseCount, Scale));
806 }
807
808 llvm::MDNode *
createProfileWeights(ArrayRef<uint64_t> Weights)809 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) {
810 // We need at least two elements to create meaningful weights.
811 if (Weights.size() < 2)
812 return nullptr;
813
814 // Check for empty weights.
815 uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
816 if (MaxWeight == 0)
817 return nullptr;
818
819 // Calculate how to scale down to 32-bits.
820 uint64_t Scale = calculateWeightScale(MaxWeight);
821
822 SmallVector<uint32_t, 16> ScaledWeights;
823 ScaledWeights.reserve(Weights.size());
824 for (uint64_t W : Weights)
825 ScaledWeights.push_back(scaleBranchWeight(W, Scale));
826
827 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
828 return MDHelper.createBranchWeights(ScaledWeights);
829 }
830
createProfileWeightsForLoop(const Stmt * Cond,uint64_t LoopCount)831 llvm::MDNode *CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
832 uint64_t LoopCount) {
833 if (!PGO.haveRegionCounts())
834 return nullptr;
835 Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
836 assert(CondCount.hasValue() && "missing expected loop condition count");
837 if (*CondCount == 0)
838 return nullptr;
839 return createProfileWeights(LoopCount,
840 std::max(*CondCount, LoopCount) - LoopCount);
841 }
842