1 //===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Instrumentation-based profile-guided optimization
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CodeGenPGO.h"
14 #include "CodeGenFunction.h"
15 #include "CoverageMappingGen.h"
16 #include "clang/AST/RecursiveASTVisitor.h"
17 #include "clang/AST/StmtVisitor.h"
18 #include "llvm/IR/Intrinsics.h"
19 #include "llvm/IR/MDBuilder.h"
20 #include "llvm/Support/CommandLine.h"
21 #include "llvm/Support/Endian.h"
22 #include "llvm/Support/FileSystem.h"
23 #include "llvm/Support/MD5.h"
24 
25 static llvm::cl::opt<bool>
26     EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
27                          llvm::cl::desc("Enable value profiling"),
28                          llvm::cl::Hidden, llvm::cl::init(false));
29 
30 using namespace clang;
31 using namespace CodeGen;
32 
setFuncName(StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)33 void CodeGenPGO::setFuncName(StringRef Name,
34                              llvm::GlobalValue::LinkageTypes Linkage) {
35   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36   FuncName = llvm::getPGOFuncName(
37       Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38       PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39 
40   // If we're generating a profile, create a variable for the name.
41   if (CGM.getCodeGenOpts().hasProfileClangInstr())
42     FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43 }
44 
setFuncName(llvm::Function * Fn)45 void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46   setFuncName(Fn->getName(), Fn->getLinkage());
47   // Create PGOFuncName meta data.
48   llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49 }
50 
/// Versions of the PGO hash algorithm, oldest first.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1 = 0,
  PGO_HASH_V2 = 1,
  PGO_HASH_V3 = 2,

  // Alias for the newest version; update it whenever a version is added.
  PGO_HASH_LATEST = PGO_HASH_V3
};
60 
61 namespace {
62 /// Stable hasher for PGO region counters.
63 ///
64 /// PGOHash produces a stable hash of a given function's control flow.
65 ///
66 /// Changing the output of this hash will invalidate all previously generated
67 /// profiles -- i.e., don't do it.
68 ///
69 /// \note  When this hash does eventually change (years?), we still need to
70 /// support old hashes.  We'll need to pull in the version number from the
71 /// profile data format and use the matching hash function.
72 class PGOHash {
73   uint64_t Working;
74   unsigned Count;
75   PGOHashVersion HashVersion;
76   llvm::MD5 MD5;
77 
78   static const int NumBitsPerType = 6;
79   static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
80   static const unsigned TooBig = 1u << NumBitsPerType;
81 
82 public:
83   /// Hash values for AST nodes.
84   ///
85   /// Distinct values for AST nodes that have region counters attached.
86   ///
87   /// These values must be stable.  All new members must be added at the end,
88   /// and no members should be removed.  Changing the enumeration value for an
89   /// AST node will affect the hash of every function that contains that node.
90   enum HashType : unsigned char {
91     None = 0,
92     LabelStmt = 1,
93     WhileStmt,
94     DoStmt,
95     ForStmt,
96     CXXForRangeStmt,
97     ObjCForCollectionStmt,
98     SwitchStmt,
99     CaseStmt,
100     DefaultStmt,
101     IfStmt,
102     CXXTryStmt,
103     CXXCatchStmt,
104     ConditionalOperator,
105     BinaryOperatorLAnd,
106     BinaryOperatorLOr,
107     BinaryConditionalOperator,
108     // The preceding values are available with PGO_HASH_V1.
109 
110     EndOfScope,
111     IfThenBranch,
112     IfElseBranch,
113     GotoStmt,
114     IndirectGotoStmt,
115     BreakStmt,
116     ContinueStmt,
117     ReturnStmt,
118     ThrowExpr,
119     UnaryOperatorLNot,
120     BinaryOperatorLT,
121     BinaryOperatorGT,
122     BinaryOperatorLE,
123     BinaryOperatorGE,
124     BinaryOperatorEQ,
125     BinaryOperatorNE,
126     // The preceding values are available since PGO_HASH_V2.
127 
128     // Keep this last.  It's for the static assert that follows.
129     LastHashType
130   };
131   static_assert(LastHashType <= TooBig, "Too many types in HashType");
132 
PGOHash(PGOHashVersion HashVersion)133   PGOHash(PGOHashVersion HashVersion)
134       : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
135   void combine(HashType Type);
136   uint64_t finalize();
getHashVersion() const137   PGOHashVersion getHashVersion() const { return HashVersion; }
138 };
139 const int PGOHash::NumBitsPerType;
140 const unsigned PGOHash::NumTypesPerWord;
141 const unsigned PGOHash::TooBig;
142 
143 /// Get the PGO hash version used in the given indexed profile.
getPGOHashVersion(llvm::IndexedInstrProfReader * PGOReader,CodeGenModule & CGM)144 static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
145                                         CodeGenModule &CGM) {
146   if (PGOReader->getVersion() <= 4)
147     return PGO_HASH_V1;
148   if (PGOReader->getVersion() <= 5)
149     return PGO_HASH_V2;
150   return PGO_HASH_V3;
151 }
152 
153 /// A RecursiveASTVisitor that fills a map of statements to PGO counters.
154 struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
155   using Base = RecursiveASTVisitor<MapRegionCounters>;
156 
157   /// The next counter value to assign.
158   unsigned NextCounter;
159   /// The function hash.
160   PGOHash Hash;
161   /// The map of statements to counters.
162   llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
163 
MapRegionCounters__anon90a142290111::MapRegionCounters164   MapRegionCounters(PGOHashVersion HashVersion,
165                     llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
166       : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap) {}
167 
168   // Blocks and lambdas are handled as separate functions, so we need not
169   // traverse them in the parent context.
TraverseBlockExpr__anon90a142290111::MapRegionCounters170   bool TraverseBlockExpr(BlockExpr *BE) { return true; }
TraverseLambdaExpr__anon90a142290111::MapRegionCounters171   bool TraverseLambdaExpr(LambdaExpr *LE) {
172     // Traverse the captures, but not the body.
173     for (auto C : zip(LE->captures(), LE->capture_inits()))
174       TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
175     return true;
176   }
TraverseCapturedStmt__anon90a142290111::MapRegionCounters177   bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }
178 
VisitDecl__anon90a142290111::MapRegionCounters179   bool VisitDecl(const Decl *D) {
180     switch (D->getKind()) {
181     default:
182       break;
183     case Decl::Function:
184     case Decl::CXXMethod:
185     case Decl::CXXConstructor:
186     case Decl::CXXDestructor:
187     case Decl::CXXConversion:
188     case Decl::ObjCMethod:
189     case Decl::Block:
190     case Decl::Captured:
191       CounterMap[D->getBody()] = NextCounter++;
192       break;
193     }
194     return true;
195   }
196 
197   /// If \p S gets a fresh counter, update the counter mappings. Return the
198   /// V1 hash of \p S.
updateCounterMappings__anon90a142290111::MapRegionCounters199   PGOHash::HashType updateCounterMappings(Stmt *S) {
200     auto Type = getHashType(PGO_HASH_V1, S);
201     if (Type != PGOHash::None)
202       CounterMap[S] = NextCounter++;
203     return Type;
204   }
205 
206   /// Include \p S in the function hash.
VisitStmt__anon90a142290111::MapRegionCounters207   bool VisitStmt(Stmt *S) {
208     auto Type = updateCounterMappings(S);
209     if (Hash.getHashVersion() != PGO_HASH_V1)
210       Type = getHashType(Hash.getHashVersion(), S);
211     if (Type != PGOHash::None)
212       Hash.combine(Type);
213     return true;
214   }
215 
TraverseIfStmt__anon90a142290111::MapRegionCounters216   bool TraverseIfStmt(IfStmt *If) {
217     // If we used the V1 hash, use the default traversal.
218     if (Hash.getHashVersion() == PGO_HASH_V1)
219       return Base::TraverseIfStmt(If);
220 
221     // Otherwise, keep track of which branch we're in while traversing.
222     VisitStmt(If);
223     for (Stmt *CS : If->children()) {
224       if (!CS)
225         continue;
226       if (CS == If->getThen())
227         Hash.combine(PGOHash::IfThenBranch);
228       else if (CS == If->getElse())
229         Hash.combine(PGOHash::IfElseBranch);
230       TraverseStmt(CS);
231     }
232     Hash.combine(PGOHash::EndOfScope);
233     return true;
234   }
235 
236 // If the statement type \p N is nestable, and its nesting impacts profile
237 // stability, define a custom traversal which tracks the end of the statement
238 // in the hash (provided we're not using the V1 hash).
239 #define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
240   bool Traverse##N(N *S) {                                                     \
241     Base::Traverse##N(S);                                                      \
242     if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
243       Hash.combine(PGOHash::EndOfScope);                                       \
244     return true;                                                               \
245   }
246 
247   DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
DEFINE_NESTABLE_TRAVERSAL__anon90a142290111::MapRegionCounters248   DEFINE_NESTABLE_TRAVERSAL(DoStmt)
249   DEFINE_NESTABLE_TRAVERSAL(ForStmt)
250   DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
251   DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
252   DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
253   DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)
254 
255   /// Get version \p HashVersion of the PGO hash for \p S.
256   PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
257     switch (S->getStmtClass()) {
258     default:
259       break;
260     case Stmt::LabelStmtClass:
261       return PGOHash::LabelStmt;
262     case Stmt::WhileStmtClass:
263       return PGOHash::WhileStmt;
264     case Stmt::DoStmtClass:
265       return PGOHash::DoStmt;
266     case Stmt::ForStmtClass:
267       return PGOHash::ForStmt;
268     case Stmt::CXXForRangeStmtClass:
269       return PGOHash::CXXForRangeStmt;
270     case Stmt::ObjCForCollectionStmtClass:
271       return PGOHash::ObjCForCollectionStmt;
272     case Stmt::SwitchStmtClass:
273       return PGOHash::SwitchStmt;
274     case Stmt::CaseStmtClass:
275       return PGOHash::CaseStmt;
276     case Stmt::DefaultStmtClass:
277       return PGOHash::DefaultStmt;
278     case Stmt::IfStmtClass:
279       return PGOHash::IfStmt;
280     case Stmt::CXXTryStmtClass:
281       return PGOHash::CXXTryStmt;
282     case Stmt::CXXCatchStmtClass:
283       return PGOHash::CXXCatchStmt;
284     case Stmt::ConditionalOperatorClass:
285       return PGOHash::ConditionalOperator;
286     case Stmt::BinaryConditionalOperatorClass:
287       return PGOHash::BinaryConditionalOperator;
288     case Stmt::BinaryOperatorClass: {
289       const BinaryOperator *BO = cast<BinaryOperator>(S);
290       if (BO->getOpcode() == BO_LAnd)
291         return PGOHash::BinaryOperatorLAnd;
292       if (BO->getOpcode() == BO_LOr)
293         return PGOHash::BinaryOperatorLOr;
294       if (HashVersion >= PGO_HASH_V2) {
295         switch (BO->getOpcode()) {
296         default:
297           break;
298         case BO_LT:
299           return PGOHash::BinaryOperatorLT;
300         case BO_GT:
301           return PGOHash::BinaryOperatorGT;
302         case BO_LE:
303           return PGOHash::BinaryOperatorLE;
304         case BO_GE:
305           return PGOHash::BinaryOperatorGE;
306         case BO_EQ:
307           return PGOHash::BinaryOperatorEQ;
308         case BO_NE:
309           return PGOHash::BinaryOperatorNE;
310         }
311       }
312       break;
313     }
314     }
315 
316     if (HashVersion >= PGO_HASH_V2) {
317       switch (S->getStmtClass()) {
318       default:
319         break;
320       case Stmt::GotoStmtClass:
321         return PGOHash::GotoStmt;
322       case Stmt::IndirectGotoStmtClass:
323         return PGOHash::IndirectGotoStmt;
324       case Stmt::BreakStmtClass:
325         return PGOHash::BreakStmt;
326       case Stmt::ContinueStmtClass:
327         return PGOHash::ContinueStmt;
328       case Stmt::ReturnStmtClass:
329         return PGOHash::ReturnStmt;
330       case Stmt::CXXThrowExprClass:
331         return PGOHash::ThrowExpr;
332       case Stmt::UnaryOperatorClass: {
333         const UnaryOperator *UO = cast<UnaryOperator>(S);
334         if (UO->getOpcode() == UO_LNot)
335           return PGOHash::UnaryOperatorLNot;
336         break;
337       }
338       }
339     }
340 
341     return PGOHash::None;
342   }
343 };
344 
345 /// A StmtVisitor that propagates the raw counts through the AST and
346 /// records the count at statements where the value may change.
347 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
348   /// PGO state.
349   CodeGenPGO &PGO;
350 
351   /// A flag that is set when the current count should be recorded on the
352   /// next statement, such as at the exit of a loop.
353   bool RecordNextStmtCount;
354 
355   /// The count at the current location in the traversal.
356   uint64_t CurrentCount;
357 
358   /// The map of statements to count values.
359   llvm::DenseMap<const Stmt *, uint64_t> &CountMap;
360 
361   /// BreakContinueStack - Keep counts of breaks and continues inside loops.
362   struct BreakContinue {
363     uint64_t BreakCount;
364     uint64_t ContinueCount;
BreakContinue__anon90a142290111::ComputeRegionCounts::BreakContinue365     BreakContinue() : BreakCount(0), ContinueCount(0) {}
366   };
367   SmallVector<BreakContinue, 8> BreakContinueStack;
368 
ComputeRegionCounts__anon90a142290111::ComputeRegionCounts369   ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
370                       CodeGenPGO &PGO)
371       : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}
372 
RecordStmtCount__anon90a142290111::ComputeRegionCounts373   void RecordStmtCount(const Stmt *S) {
374     if (RecordNextStmtCount) {
375       CountMap[S] = CurrentCount;
376       RecordNextStmtCount = false;
377     }
378   }
379 
380   /// Set and return the current count.
setCount__anon90a142290111::ComputeRegionCounts381   uint64_t setCount(uint64_t Count) {
382     CurrentCount = Count;
383     return Count;
384   }
385 
VisitStmt__anon90a142290111::ComputeRegionCounts386   void VisitStmt(const Stmt *S) {
387     RecordStmtCount(S);
388     for (const Stmt *Child : S->children())
389       if (Child)
390         this->Visit(Child);
391   }
392 
VisitFunctionDecl__anon90a142290111::ComputeRegionCounts393   void VisitFunctionDecl(const FunctionDecl *D) {
394     // Counter tracks entry to the function body.
395     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
396     CountMap[D->getBody()] = BodyCount;
397     Visit(D->getBody());
398   }
399 
400   // Skip lambda expressions. We visit these as FunctionDecls when we're
401   // generating them and aren't interested in the body when generating a
402   // parent context.
VisitLambdaExpr__anon90a142290111::ComputeRegionCounts403   void VisitLambdaExpr(const LambdaExpr *LE) {}
404 
VisitCapturedDecl__anon90a142290111::ComputeRegionCounts405   void VisitCapturedDecl(const CapturedDecl *D) {
406     // Counter tracks entry to the capture body.
407     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
408     CountMap[D->getBody()] = BodyCount;
409     Visit(D->getBody());
410   }
411 
VisitObjCMethodDecl__anon90a142290111::ComputeRegionCounts412   void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
413     // Counter tracks entry to the method body.
414     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
415     CountMap[D->getBody()] = BodyCount;
416     Visit(D->getBody());
417   }
418 
VisitBlockDecl__anon90a142290111::ComputeRegionCounts419   void VisitBlockDecl(const BlockDecl *D) {
420     // Counter tracks entry to the block body.
421     uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
422     CountMap[D->getBody()] = BodyCount;
423     Visit(D->getBody());
424   }
425 
VisitReturnStmt__anon90a142290111::ComputeRegionCounts426   void VisitReturnStmt(const ReturnStmt *S) {
427     RecordStmtCount(S);
428     if (S->getRetValue())
429       Visit(S->getRetValue());
430     CurrentCount = 0;
431     RecordNextStmtCount = true;
432   }
433 
VisitCXXThrowExpr__anon90a142290111::ComputeRegionCounts434   void VisitCXXThrowExpr(const CXXThrowExpr *E) {
435     RecordStmtCount(E);
436     if (E->getSubExpr())
437       Visit(E->getSubExpr());
438     CurrentCount = 0;
439     RecordNextStmtCount = true;
440   }
441 
VisitGotoStmt__anon90a142290111::ComputeRegionCounts442   void VisitGotoStmt(const GotoStmt *S) {
443     RecordStmtCount(S);
444     CurrentCount = 0;
445     RecordNextStmtCount = true;
446   }
447 
VisitLabelStmt__anon90a142290111::ComputeRegionCounts448   void VisitLabelStmt(const LabelStmt *S) {
449     RecordNextStmtCount = false;
450     // Counter tracks the block following the label.
451     uint64_t BlockCount = setCount(PGO.getRegionCount(S));
452     CountMap[S] = BlockCount;
453     Visit(S->getSubStmt());
454   }
455 
VisitBreakStmt__anon90a142290111::ComputeRegionCounts456   void VisitBreakStmt(const BreakStmt *S) {
457     RecordStmtCount(S);
458     assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
459     BreakContinueStack.back().BreakCount += CurrentCount;
460     CurrentCount = 0;
461     RecordNextStmtCount = true;
462   }
463 
VisitContinueStmt__anon90a142290111::ComputeRegionCounts464   void VisitContinueStmt(const ContinueStmt *S) {
465     RecordStmtCount(S);
466     assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
467     BreakContinueStack.back().ContinueCount += CurrentCount;
468     CurrentCount = 0;
469     RecordNextStmtCount = true;
470   }
471 
VisitWhileStmt__anon90a142290111::ComputeRegionCounts472   void VisitWhileStmt(const WhileStmt *S) {
473     RecordStmtCount(S);
474     uint64_t ParentCount = CurrentCount;
475 
476     BreakContinueStack.push_back(BreakContinue());
477     // Visit the body region first so the break/continue adjustments can be
478     // included when visiting the condition.
479     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
480     CountMap[S->getBody()] = CurrentCount;
481     Visit(S->getBody());
482     uint64_t BackedgeCount = CurrentCount;
483 
484     // ...then go back and propagate counts through the condition. The count
485     // at the start of the condition is the sum of the incoming edges,
486     // the backedge from the end of the loop body, and the edges from
487     // continue statements.
488     BreakContinue BC = BreakContinueStack.pop_back_val();
489     uint64_t CondCount =
490         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
491     CountMap[S->getCond()] = CondCount;
492     Visit(S->getCond());
493     setCount(BC.BreakCount + CondCount - BodyCount);
494     RecordNextStmtCount = true;
495   }
496 
VisitDoStmt__anon90a142290111::ComputeRegionCounts497   void VisitDoStmt(const DoStmt *S) {
498     RecordStmtCount(S);
499     uint64_t LoopCount = PGO.getRegionCount(S);
500 
501     BreakContinueStack.push_back(BreakContinue());
502     // The count doesn't include the fallthrough from the parent scope. Add it.
503     uint64_t BodyCount = setCount(LoopCount + CurrentCount);
504     CountMap[S->getBody()] = BodyCount;
505     Visit(S->getBody());
506     uint64_t BackedgeCount = CurrentCount;
507 
508     BreakContinue BC = BreakContinueStack.pop_back_val();
509     // The count at the start of the condition is equal to the count at the
510     // end of the body, plus any continues.
511     uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
512     CountMap[S->getCond()] = CondCount;
513     Visit(S->getCond());
514     setCount(BC.BreakCount + CondCount - LoopCount);
515     RecordNextStmtCount = true;
516   }
517 
VisitForStmt__anon90a142290111::ComputeRegionCounts518   void VisitForStmt(const ForStmt *S) {
519     RecordStmtCount(S);
520     if (S->getInit())
521       Visit(S->getInit());
522 
523     uint64_t ParentCount = CurrentCount;
524 
525     BreakContinueStack.push_back(BreakContinue());
526     // Visit the body region first. (This is basically the same as a while
527     // loop; see further comments in VisitWhileStmt.)
528     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
529     CountMap[S->getBody()] = BodyCount;
530     Visit(S->getBody());
531     uint64_t BackedgeCount = CurrentCount;
532     BreakContinue BC = BreakContinueStack.pop_back_val();
533 
534     // The increment is essentially part of the body but it needs to include
535     // the count for all the continue statements.
536     if (S->getInc()) {
537       uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
538       CountMap[S->getInc()] = IncCount;
539       Visit(S->getInc());
540     }
541 
542     // ...then go back and propagate counts through the condition.
543     uint64_t CondCount =
544         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
545     if (S->getCond()) {
546       CountMap[S->getCond()] = CondCount;
547       Visit(S->getCond());
548     }
549     setCount(BC.BreakCount + CondCount - BodyCount);
550     RecordNextStmtCount = true;
551   }
552 
VisitCXXForRangeStmt__anon90a142290111::ComputeRegionCounts553   void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
554     RecordStmtCount(S);
555     if (S->getInit())
556       Visit(S->getInit());
557     Visit(S->getLoopVarStmt());
558     Visit(S->getRangeStmt());
559     Visit(S->getBeginStmt());
560     Visit(S->getEndStmt());
561 
562     uint64_t ParentCount = CurrentCount;
563     BreakContinueStack.push_back(BreakContinue());
564     // Visit the body region first. (This is basically the same as a while
565     // loop; see further comments in VisitWhileStmt.)
566     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
567     CountMap[S->getBody()] = BodyCount;
568     Visit(S->getBody());
569     uint64_t BackedgeCount = CurrentCount;
570     BreakContinue BC = BreakContinueStack.pop_back_val();
571 
572     // The increment is essentially part of the body but it needs to include
573     // the count for all the continue statements.
574     uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
575     CountMap[S->getInc()] = IncCount;
576     Visit(S->getInc());
577 
578     // ...then go back and propagate counts through the condition.
579     uint64_t CondCount =
580         setCount(ParentCount + BackedgeCount + BC.ContinueCount);
581     CountMap[S->getCond()] = CondCount;
582     Visit(S->getCond());
583     setCount(BC.BreakCount + CondCount - BodyCount);
584     RecordNextStmtCount = true;
585   }
586 
VisitObjCForCollectionStmt__anon90a142290111::ComputeRegionCounts587   void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
588     RecordStmtCount(S);
589     Visit(S->getElement());
590     uint64_t ParentCount = CurrentCount;
591     BreakContinueStack.push_back(BreakContinue());
592     // Counter tracks the body of the loop.
593     uint64_t BodyCount = setCount(PGO.getRegionCount(S));
594     CountMap[S->getBody()] = BodyCount;
595     Visit(S->getBody());
596     uint64_t BackedgeCount = CurrentCount;
597     BreakContinue BC = BreakContinueStack.pop_back_val();
598 
599     setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
600              BodyCount);
601     RecordNextStmtCount = true;
602   }
603 
VisitSwitchStmt__anon90a142290111::ComputeRegionCounts604   void VisitSwitchStmt(const SwitchStmt *S) {
605     RecordStmtCount(S);
606     if (S->getInit())
607       Visit(S->getInit());
608     Visit(S->getCond());
609     CurrentCount = 0;
610     BreakContinueStack.push_back(BreakContinue());
611     Visit(S->getBody());
612     // If the switch is inside a loop, add the continue counts.
613     BreakContinue BC = BreakContinueStack.pop_back_val();
614     if (!BreakContinueStack.empty())
615       BreakContinueStack.back().ContinueCount += BC.ContinueCount;
616     // Counter tracks the exit block of the switch.
617     setCount(PGO.getRegionCount(S));
618     RecordNextStmtCount = true;
619   }
620 
VisitSwitchCase__anon90a142290111::ComputeRegionCounts621   void VisitSwitchCase(const SwitchCase *S) {
622     RecordNextStmtCount = false;
623     // Counter for this particular case. This counts only jumps from the
624     // switch header and does not include fallthrough from the case before
625     // this one.
626     uint64_t CaseCount = PGO.getRegionCount(S);
627     setCount(CurrentCount + CaseCount);
628     // We need the count without fallthrough in the mapping, so it's more useful
629     // for branch probabilities.
630     CountMap[S] = CaseCount;
631     RecordNextStmtCount = true;
632     Visit(S->getSubStmt());
633   }
634 
VisitIfStmt__anon90a142290111::ComputeRegionCounts635   void VisitIfStmt(const IfStmt *S) {
636     RecordStmtCount(S);
637     uint64_t ParentCount = CurrentCount;
638     if (S->getInit())
639       Visit(S->getInit());
640     Visit(S->getCond());
641 
642     // Counter tracks the "then" part of an if statement. The count for
643     // the "else" part, if it exists, will be calculated from this counter.
644     uint64_t ThenCount = setCount(PGO.getRegionCount(S));
645     CountMap[S->getThen()] = ThenCount;
646     Visit(S->getThen());
647     uint64_t OutCount = CurrentCount;
648 
649     uint64_t ElseCount = ParentCount - ThenCount;
650     if (S->getElse()) {
651       setCount(ElseCount);
652       CountMap[S->getElse()] = ElseCount;
653       Visit(S->getElse());
654       OutCount += CurrentCount;
655     } else
656       OutCount += ElseCount;
657     setCount(OutCount);
658     RecordNextStmtCount = true;
659   }
660 
VisitCXXTryStmt__anon90a142290111::ComputeRegionCounts661   void VisitCXXTryStmt(const CXXTryStmt *S) {
662     RecordStmtCount(S);
663     Visit(S->getTryBlock());
664     for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
665       Visit(S->getHandler(I));
666     // Counter tracks the continuation block of the try statement.
667     setCount(PGO.getRegionCount(S));
668     RecordNextStmtCount = true;
669   }
670 
VisitCXXCatchStmt__anon90a142290111::ComputeRegionCounts671   void VisitCXXCatchStmt(const CXXCatchStmt *S) {
672     RecordNextStmtCount = false;
673     // Counter tracks the catch statement's handler block.
674     uint64_t CatchCount = setCount(PGO.getRegionCount(S));
675     CountMap[S] = CatchCount;
676     Visit(S->getHandlerBlock());
677   }
678 
VisitAbstractConditionalOperator__anon90a142290111::ComputeRegionCounts679   void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
680     RecordStmtCount(E);
681     uint64_t ParentCount = CurrentCount;
682     Visit(E->getCond());
683 
684     // Counter tracks the "true" part of a conditional operator. The
685     // count in the "false" part will be calculated from this counter.
686     uint64_t TrueCount = setCount(PGO.getRegionCount(E));
687     CountMap[E->getTrueExpr()] = TrueCount;
688     Visit(E->getTrueExpr());
689     uint64_t OutCount = CurrentCount;
690 
691     uint64_t FalseCount = setCount(ParentCount - TrueCount);
692     CountMap[E->getFalseExpr()] = FalseCount;
693     Visit(E->getFalseExpr());
694     OutCount += CurrentCount;
695 
696     setCount(OutCount);
697     RecordNextStmtCount = true;
698   }
699 
VisitBinLAnd__anon90a142290111::ComputeRegionCounts700   void VisitBinLAnd(const BinaryOperator *E) {
701     RecordStmtCount(E);
702     uint64_t ParentCount = CurrentCount;
703     Visit(E->getLHS());
704     // Counter tracks the right hand side of a logical and operator.
705     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
706     CountMap[E->getRHS()] = RHSCount;
707     Visit(E->getRHS());
708     setCount(ParentCount + RHSCount - CurrentCount);
709     RecordNextStmtCount = true;
710   }
711 
VisitBinLOr__anon90a142290111::ComputeRegionCounts712   void VisitBinLOr(const BinaryOperator *E) {
713     RecordStmtCount(E);
714     uint64_t ParentCount = CurrentCount;
715     Visit(E->getLHS());
716     // Counter tracks the right hand side of a logical or operator.
717     uint64_t RHSCount = setCount(PGO.getRegionCount(E));
718     CountMap[E->getRHS()] = RHSCount;
719     Visit(E->getRHS());
720     setCount(ParentCount + RHSCount - CurrentCount);
721     RecordNextStmtCount = true;
722   }
723 };
724 } // end anonymous namespace
725 
combine(HashType Type)726 void PGOHash::combine(HashType Type) {
727   // Check that we never combine 0 and only have six bits.
728   assert(Type && "Hash is invalid: unexpected type 0");
729   assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");
730 
731   // Pass through MD5 if enough work has built up.
732   if (Count && Count % NumTypesPerWord == 0) {
733     using namespace llvm::support;
734     uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
735     MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
736     Working = 0;
737   }
738 
739   // Accumulate the current type.
740   ++Count;
741   Working = Working << NumBitsPerType | Type;
742 }
743 
finalize()744 uint64_t PGOHash::finalize() {
745   // Use Working as the hash directly if we never used MD5.
746   if (Count <= NumTypesPerWord)
747     // No need to byte swap here, since none of the math was endian-dependent.
748     // This number will be byte-swapped as required on endianness transitions,
749     // so we will see the same value on the other side.
750     return Working;
751 
752   // Check for remaining work in Working.
753   if (Working) {
754     // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
755     // is buggy because it converts a uint64_t into an array of uint8_t.
756     if (HashVersion < PGO_HASH_V3) {
757       MD5.update({(uint8_t)Working});
758     } else {
759       using namespace llvm::support;
760       uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
761       MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
762     }
763   }
764 
765   // Finalize the MD5 and return the hash.
766   llvm::MD5::MD5Result Result;
767   MD5.final(Result);
768   return Result.low();
769 }
770 
assignRegionCounters(GlobalDecl GD,llvm::Function * Fn)771 void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
772   const Decl *D = GD.getDecl();
773   if (!D->hasBody())
774     return;
775 
776   // Skip CUDA/HIP kernel launch stub functions.
777   if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
778       D->hasAttr<CUDAGlobalAttr>())
779     return;
780 
781   bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
782   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
783   if (!InstrumentRegions && !PGOReader)
784     return;
785   if (D->isImplicit())
786     return;
787   // Constructors and destructors may be represented by several functions in IR.
788   // If so, instrument only base variant, others are implemented by delegation
789   // to the base one, it would be counted twice otherwise.
790   if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
791     if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
792       if (GD.getCtorType() != Ctor_Base &&
793           CodeGenFunction::IsConstructorDelegationValid(CCD))
794         return;
795   }
796   if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
797     return;
798 
799   CGM.ClearUnusedCoverageMapping(D);
800   setFuncName(Fn);
801 
802   mapRegionCounters(D);
803   if (CGM.getCodeGenOpts().CoverageMapping)
804     emitCounterRegionMapping(D);
805   if (PGOReader) {
806     SourceManager &SM = CGM.getContext().getSourceManager();
807     loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
808     computeRegionCounts(D);
809     applyFunctionAttributes(PGOReader, Fn);
810   }
811 }
812 
mapRegionCounters(const Decl * D)813 void CodeGenPGO::mapRegionCounters(const Decl *D) {
814   // Use the latest hash version when inserting instrumentation, but use the
815   // version in the indexed profile if we're reading PGO data.
816   PGOHashVersion HashVersion = PGO_HASH_LATEST;
817   if (auto *PGOReader = CGM.getPGOReader())
818     HashVersion = getPGOHashVersion(PGOReader, CGM);
819 
820   RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
821   MapRegionCounters Walker(HashVersion, *RegionCounterMap);
822   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
823     Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
824   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
825     Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
826   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
827     Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
828   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
829     Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
830   assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
831   NumRegionCounters = Walker.NextCounter;
832   FunctionHash = Walker.Hash.finalize();
833 }
834 
skipRegionMappingForDecl(const Decl * D)835 bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
836   if (!D->getBody())
837     return true;
838 
839   // Skip host-only functions in the CUDA device compilation and device-only
840   // functions in the host compilation. Just roughly filter them out based on
841   // the function attributes. If there are effectively host-only or device-only
842   // ones, their coverage mapping may still be generated.
843   if (CGM.getLangOpts().CUDA &&
844       ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
845         !D->hasAttr<CUDAGlobalAttr>()) ||
846        (!CGM.getLangOpts().CUDAIsDevice &&
847         (D->hasAttr<CUDAGlobalAttr>() ||
848          (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
849     return true;
850 
851   // Don't map the functions in system headers.
852   const auto &SM = CGM.getContext().getSourceManager();
853   auto Loc = D->getBody()->getBeginLoc();
854   return SM.isInSystemHeader(Loc);
855 }
856 
emitCounterRegionMapping(const Decl * D)857 void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
858   if (skipRegionMappingForDecl(D))
859     return;
860 
861   std::string CoverageMapping;
862   llvm::raw_string_ostream OS(CoverageMapping);
863   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
864                                 CGM.getContext().getSourceManager(),
865                                 CGM.getLangOpts(), RegionCounterMap.get());
866   MappingGen.emitCounterMapping(D, OS);
867   OS.flush();
868 
869   if (CoverageMapping.empty())
870     return;
871 
872   CGM.getCoverageMapping()->addFunctionMappingRecord(
873       FuncNameVar, FuncName, FunctionHash, CoverageMapping);
874 }
875 
876 void
emitEmptyCounterMapping(const Decl * D,StringRef Name,llvm::GlobalValue::LinkageTypes Linkage)877 CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
878                                     llvm::GlobalValue::LinkageTypes Linkage) {
879   if (skipRegionMappingForDecl(D))
880     return;
881 
882   std::string CoverageMapping;
883   llvm::raw_string_ostream OS(CoverageMapping);
884   CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
885                                 CGM.getContext().getSourceManager(),
886                                 CGM.getLangOpts());
887   MappingGen.emitEmptyMapping(D, OS);
888   OS.flush();
889 
890   if (CoverageMapping.empty())
891     return;
892 
893   setFuncName(Name, Linkage);
894   CGM.getCoverageMapping()->addFunctionMappingRecord(
895       FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
896 }
897 
computeRegionCounts(const Decl * D)898 void CodeGenPGO::computeRegionCounts(const Decl *D) {
899   StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
900   ComputeRegionCounts Walker(*StmtCountMap, *this);
901   if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
902     Walker.VisitFunctionDecl(FD);
903   else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
904     Walker.VisitObjCMethodDecl(MD);
905   else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
906     Walker.VisitBlockDecl(BD);
907   else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
908     Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
909 }
910 
911 void
applyFunctionAttributes(llvm::IndexedInstrProfReader * PGOReader,llvm::Function * Fn)912 CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
913                                     llvm::Function *Fn) {
914   if (!haveRegionCounts())
915     return;
916 
917   uint64_t FunctionCount = getRegionCount(nullptr);
918   Fn->setEntryCount(FunctionCount);
919 }
920 
emitCounterIncrement(CGBuilderTy & Builder,const Stmt * S,llvm::Value * StepV)921 void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
922                                       llvm::Value *StepV) {
923   if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
924     return;
925   if (!Builder.GetInsertBlock())
926     return;
927 
928   unsigned Counter = (*RegionCounterMap)[S];
929   auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
930 
931   llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
932                          Builder.getInt64(FunctionHash),
933                          Builder.getInt32(NumRegionCounters),
934                          Builder.getInt32(Counter), StepV};
935   if (!StepV)
936     Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
937                        makeArrayRef(Args, 4));
938   else
939     Builder.CreateCall(
940         CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
941         makeArrayRef(Args));
942 }
943 
944 // This method either inserts a call to the profile run-time during
945 // instrumentation or puts profile data into metadata for PGO use.
valueProfile(CGBuilderTy & Builder,uint32_t ValueKind,llvm::Instruction * ValueSite,llvm::Value * ValuePtr)946 void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
947     llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {
948 
949   if (!EnableValueProfiling)
950     return;
951 
952   if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
953     return;
954 
955   if (isa<llvm::Constant>(ValuePtr))
956     return;
957 
958   bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
959   if (InstrumentValueSites && RegionCounterMap) {
960     auto BuilderInsertPoint = Builder.saveIP();
961     Builder.SetInsertPoint(ValueSite);
962     llvm::Value *Args[5] = {
963         llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
964         Builder.getInt64(FunctionHash),
965         Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
966         Builder.getInt32(ValueKind),
967         Builder.getInt32(NumValueSites[ValueKind]++)
968     };
969     Builder.CreateCall(
970         CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
971     Builder.restoreIP(BuilderInsertPoint);
972     return;
973   }
974 
975   llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
976   if (PGOReader && haveRegionCounts()) {
977     // We record the top most called three functions at each call site.
978     // Profile metadata contains "VP" string identifying this metadata
979     // as value profiling data, then a uint32_t value for the value profiling
980     // kind, a uint64_t value for the total number of times the call is
981     // executed, followed by the function hash and execution count (uint64_t)
982     // pairs for each function.
983     if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
984       return;
985 
986     llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
987                             (llvm::InstrProfValueKind)ValueKind,
988                             NumValueSites[ValueKind]);
989 
990     NumValueSites[ValueKind]++;
991   }
992 }
993 
loadRegionCounts(llvm::IndexedInstrProfReader * PGOReader,bool IsInMainFile)994 void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
995                                   bool IsInMainFile) {
996   CGM.getPGOStats().addVisited(IsInMainFile);
997   RegionCounts.clear();
998   llvm::Expected<llvm::InstrProfRecord> RecordExpected =
999       PGOReader->getInstrProfRecord(FuncName, FunctionHash);
1000   if (auto E = RecordExpected.takeError()) {
1001     auto IPE = llvm::InstrProfError::take(std::move(E));
1002     if (IPE == llvm::instrprof_error::unknown_function)
1003       CGM.getPGOStats().addMissing(IsInMainFile);
1004     else if (IPE == llvm::instrprof_error::hash_mismatch)
1005       CGM.getPGOStats().addMismatched(IsInMainFile);
1006     else if (IPE == llvm::instrprof_error::malformed)
1007       // TODO: Consider a more specific warning for this case.
1008       CGM.getPGOStats().addMismatched(IsInMainFile);
1009     return;
1010   }
1011   ProfRecord =
1012       std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
1013   RegionCounts = ProfRecord->Counts;
1014 }
1015 
/// Compute the divisor used to scale branch weights down.
///
/// For a maximum weight \p MaxWeight, the returned divisor is 1 when the
/// weight is already below UINT32_MAX, and otherwise the smallest value of
/// the form MaxWeight / UINT32_MAX + 1, so that every weight up to
/// \p MaxWeight divided by it is strictly less than UINT32_MAX.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  const uint64_t WholeMultiples = MaxWeight / UINT32_MAX;
  return WholeMultiples + 1;
}
1023 
/// Scale a single branch weight and bias it by one.
///
/// Divides the 64-bit \p Weight by \p Scale to bring it into 32-bit range.
///
/// Following Laplace's Rule of Succession, the weight is derived from the
/// count plus 1, so 1 is always added to the scaled value.
///
/// \pre \c Scale is non-zero and was produced by \a calculateWeightScale()
/// for a maximum weight that is no smaller than \c Weight, so that the result
/// fits in 32 bits.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Quotient = Weight / Scale;
  const uint64_t Biased = Quotient + 1;
  assert(Biased <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Biased);
}
1039 
createProfileWeights(uint64_t TrueCount,uint64_t FalseCount) const1040 llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1041                                                     uint64_t FalseCount) const {
1042   // Check for empty weights.
1043   if (!TrueCount && !FalseCount)
1044     return nullptr;
1045 
1046   // Calculate how to scale down to 32-bits.
1047   uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1048 
1049   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1050   return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1051                                       scaleBranchWeight(FalseCount, Scale));
1052 }
1053 
1054 llvm::MDNode *
createProfileWeights(ArrayRef<uint64_t> Weights) const1055 CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1056   // We need at least two elements to create meaningful weights.
1057   if (Weights.size() < 2)
1058     return nullptr;
1059 
1060   // Check for empty weights.
1061   uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1062   if (MaxWeight == 0)
1063     return nullptr;
1064 
1065   // Calculate how to scale down to 32-bits.
1066   uint64_t Scale = calculateWeightScale(MaxWeight);
1067 
1068   SmallVector<uint32_t, 16> ScaledWeights;
1069   ScaledWeights.reserve(Weights.size());
1070   for (uint64_t W : Weights)
1071     ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1072 
1073   llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1074   return MDHelper.createBranchWeights(ScaledWeights);
1075 }
1076 
1077 llvm::MDNode *
createProfileWeightsForLoop(const Stmt * Cond,uint64_t LoopCount) const1078 CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1079                                              uint64_t LoopCount) const {
1080   if (!PGO.haveRegionCounts())
1081     return nullptr;
1082   Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1083   if (!CondCount || *CondCount == 0)
1084     return nullptr;
1085   return createProfileWeights(LoopCount,
1086                               std::max(*CondCount, LoopCount) - LoopCount);
1087 }
1088