1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "GCNRegPressure.h"
11 #include "AMDGPUSubtarget.h"
12 #include "SIRegisterInfo.h"
13 #include "llvm/ADT/SmallVector.h"
14 #include "llvm/CodeGen/LiveInterval.h"
15 #include "llvm/CodeGen/LiveIntervals.h"
16 #include "llvm/CodeGen/MachineInstr.h"
17 #include "llvm/CodeGen/MachineOperand.h"
18 #include "llvm/CodeGen/MachineRegisterInfo.h"
19 #include "llvm/CodeGen/RegisterPressure.h"
20 #include "llvm/CodeGen/SlotIndexes.h"
21 #include "llvm/CodeGen/TargetRegisterInfo.h"
22 #include "llvm/Config/llvm-config.h"
23 #include "llvm/MC/LaneBitmask.h"
24 #include "llvm/Support/Compiler.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/ErrorHandling.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <algorithm>
29 #include <cassert>
30 
31 using namespace llvm;
32 
33 #define DEBUG_TYPE "machine-scheduler"
34 
35 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
36 LLVM_DUMP_METHOD
printLivesAt(SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)37 void llvm::printLivesAt(SlotIndex SI,
38                         const LiveIntervals &LIS,
39                         const MachineRegisterInfo &MRI) {
40   dbgs() << "Live regs at " << SI << ": "
41          << *LIS.getInstructionFromIndex(SI);
42   unsigned Num = 0;
43   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
44     const unsigned Reg = TargetRegisterInfo::index2VirtReg(I);
45     if (!LIS.hasInterval(Reg))
46       continue;
47     const auto &LI = LIS.getInterval(Reg);
48     if (LI.hasSubRanges()) {
49       bool firstTime = true;
50       for (const auto &S : LI.subranges()) {
51         if (!S.liveAt(SI)) continue;
52         if (firstTime) {
53           dbgs() << "  " << printReg(Reg, MRI.getTargetRegisterInfo())
54                  << '\n';
55           firstTime = false;
56         }
57         dbgs() << "  " << S << '\n';
58         ++Num;
59       }
60     } else if (LI.liveAt(SI)) {
61       dbgs() << "  " << LI << '\n';
62       ++Num;
63     }
64   }
65   if (!Num) dbgs() << "  <none>\n";
66 }
67 
isEqual(const GCNRPTracker::LiveRegSet & S1,const GCNRPTracker::LiveRegSet & S2)68 static bool isEqual(const GCNRPTracker::LiveRegSet &S1,
69                     const GCNRPTracker::LiveRegSet &S2) {
70   if (S1.size() != S2.size())
71     return false;
72 
73   for (const auto &P : S1) {
74     auto I = S2.find(P.first);
75     if (I == S2.end() || I->second != P.second)
76       return false;
77   }
78   return true;
79 }
80 #endif
81 
82 ///////////////////////////////////////////////////////////////////////////////
83 // GCNRegPressure
84 
getRegKind(unsigned Reg,const MachineRegisterInfo & MRI)85 unsigned GCNRegPressure::getRegKind(unsigned Reg,
86                                     const MachineRegisterInfo &MRI) {
87   assert(TargetRegisterInfo::isVirtualRegister(Reg));
88   const auto RC = MRI.getRegClass(Reg);
89   auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
90   return STI->isSGPRClass(RC) ?
91     (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
92     (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
93 }
94 
inc(unsigned Reg,LaneBitmask PrevMask,LaneBitmask NewMask,const MachineRegisterInfo & MRI)95 void GCNRegPressure::inc(unsigned Reg,
96                          LaneBitmask PrevMask,
97                          LaneBitmask NewMask,
98                          const MachineRegisterInfo &MRI) {
99   if (NewMask == PrevMask)
100     return;
101 
102   int Sign = 1;
103   if (NewMask < PrevMask) {
104     std::swap(NewMask, PrevMask);
105     Sign = -1;
106   }
107 #ifndef NDEBUG
108   const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg);
109 #endif
110   switch (auto Kind = getRegKind(Reg, MRI)) {
111   case SGPR32:
112   case VGPR32:
113     assert(PrevMask.none() && NewMask == MaxMask);
114     Value[Kind] += Sign;
115     break;
116 
117   case SGPR_TUPLE:
118   case VGPR_TUPLE:
119     assert(NewMask < MaxMask || NewMask == MaxMask);
120     assert(PrevMask < NewMask);
121 
122     Value[Kind == SGPR_TUPLE ? SGPR32 : VGPR32] +=
123       Sign * (~PrevMask & NewMask).getNumLanes();
124 
125     if (PrevMask.none()) {
126       assert(NewMask.any());
127       Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
128     }
129     break;
130 
131   default: llvm_unreachable("Unknown register kind");
132   }
133 }
134 
less(const GCNSubtarget & ST,const GCNRegPressure & O,unsigned MaxOccupancy) const135 bool GCNRegPressure::less(const GCNSubtarget &ST,
136                           const GCNRegPressure& O,
137                           unsigned MaxOccupancy) const {
138   const auto SGPROcc = std::min(MaxOccupancy,
139                                 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
140   const auto VGPROcc = std::min(MaxOccupancy,
141                                 ST.getOccupancyWithNumVGPRs(getVGPRNum()));
142   const auto OtherSGPROcc = std::min(MaxOccupancy,
143                                 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
144   const auto OtherVGPROcc = std::min(MaxOccupancy,
145                                 ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
146 
147   const auto Occ = std::min(SGPROcc, VGPROcc);
148   const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
149   if (Occ != OtherOcc)
150     return Occ > OtherOcc;
151 
152   bool SGPRImportant = SGPROcc < VGPROcc;
153   const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
154 
155   // if both pressures disagree on what is more important compare vgprs
156   if (SGPRImportant != OtherSGPRImportant) {
157     SGPRImportant = false;
158   }
159 
160   // compare large regs pressure
161   bool SGPRFirst = SGPRImportant;
162   for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
163     if (SGPRFirst) {
164       auto SW = getSGPRTuplesWeight();
165       auto OtherSW = O.getSGPRTuplesWeight();
166       if (SW != OtherSW)
167         return SW < OtherSW;
168     } else {
169       auto VW = getVGPRTuplesWeight();
170       auto OtherVW = O.getVGPRTuplesWeight();
171       if (VW != OtherVW)
172         return VW < OtherVW;
173     }
174   }
175   return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
176                          (getVGPRNum() < O.getVGPRNum());
177 }
178 
179 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
180 LLVM_DUMP_METHOD
print(raw_ostream & OS,const GCNSubtarget * ST) const181 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
182   OS << "VGPRs: " << getVGPRNum();
183   if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
184   OS << ", SGPRs: " << getSGPRNum();
185   if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
186   OS << ", LVGPR WT: " << getVGPRTuplesWeight()
187      << ", LSGPR WT: " << getSGPRTuplesWeight();
188   if (ST) OS << " -> Occ: " << getOccupancy(*ST);
189   OS << '\n';
190 }
191 #endif
192 
getDefRegMask(const MachineOperand & MO,const MachineRegisterInfo & MRI)193 static LaneBitmask getDefRegMask(const MachineOperand &MO,
194                                  const MachineRegisterInfo &MRI) {
195   assert(MO.isDef() && MO.isReg() &&
196     TargetRegisterInfo::isVirtualRegister(MO.getReg()));
197 
198   // We don't rely on read-undef flag because in case of tentative schedule
199   // tracking it isn't set correctly yet. This works correctly however since
200   // use mask has been tracked before using LIS.
201   return MO.getSubReg() == 0 ?
202     MRI.getMaxLaneMaskForVReg(MO.getReg()) :
203     MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
204 }
205 
getUsedRegMask(const MachineOperand & MO,const MachineRegisterInfo & MRI,const LiveIntervals & LIS)206 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
207                                   const MachineRegisterInfo &MRI,
208                                   const LiveIntervals &LIS) {
209   assert(MO.isUse() && MO.isReg() &&
210          TargetRegisterInfo::isVirtualRegister(MO.getReg()));
211 
212   if (auto SubReg = MO.getSubReg())
213     return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
214 
215   auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
216   if (MaxMask == LaneBitmask::getLane(0)) // cannot have subregs
217     return MaxMask;
218 
219   // For a tentative schedule LIS isn't updated yet but livemask should remain
220   // the same on any schedule. Subreg defs can be reordered but they all must
221   // dominate uses anyway.
222   auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
223   return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
224 }
225 
226 static SmallVector<RegisterMaskPair, 8>
collectVirtualRegUses(const MachineInstr & MI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)227 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
228                       const MachineRegisterInfo &MRI) {
229   SmallVector<RegisterMaskPair, 8> Res;
230   for (const auto &MO : MI.operands()) {
231     if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()))
232       continue;
233     if (!MO.isUse() || !MO.readsReg())
234       continue;
235 
236     auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
237 
238     auto Reg = MO.getReg();
239     auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) {
240       return RM.RegUnit == Reg;
241     });
242     if (I != Res.end())
243       I->LaneMask |= UsedMask;
244     else
245       Res.push_back(RegisterMaskPair(Reg, UsedMask));
246   }
247   return Res;
248 }
249 
250 ///////////////////////////////////////////////////////////////////////////////
251 // GCNRPTracker
252 
getLiveLaneMask(unsigned Reg,SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)253 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
254                                   SlotIndex SI,
255                                   const LiveIntervals &LIS,
256                                   const MachineRegisterInfo &MRI) {
257   LaneBitmask LiveMask;
258   const auto &LI = LIS.getInterval(Reg);
259   if (LI.hasSubRanges()) {
260     for (const auto &S : LI.subranges())
261       if (S.liveAt(SI)) {
262         LiveMask |= S.LaneMask;
263         assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
264                LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
265       }
266   } else if (LI.liveAt(SI)) {
267     LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
268   }
269   return LiveMask;
270 }
271 
getLiveRegs(SlotIndex SI,const LiveIntervals & LIS,const MachineRegisterInfo & MRI)272 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
273                                            const LiveIntervals &LIS,
274                                            const MachineRegisterInfo &MRI) {
275   GCNRPTracker::LiveRegSet LiveRegs;
276   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
277     auto Reg = TargetRegisterInfo::index2VirtReg(I);
278     if (!LIS.hasInterval(Reg))
279       continue;
280     auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
281     if (LiveMask.any())
282       LiveRegs[Reg] = LiveMask;
283   }
284   return LiveRegs;
285 }
286 
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy,bool After)287 void GCNRPTracker::reset(const MachineInstr &MI,
288                          const LiveRegSet *LiveRegsCopy,
289                          bool After) {
290   const MachineFunction &MF = *MI.getMF();
291   MRI = &MF.getRegInfo();
292   if (LiveRegsCopy) {
293     if (&LiveRegs != LiveRegsCopy)
294       LiveRegs = *LiveRegsCopy;
295   } else {
296     LiveRegs = After ? getLiveRegsAfter(MI, LIS)
297                      : getLiveRegsBefore(MI, LIS);
298   }
299 
300   MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
301 }
302 
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy)303 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
304                                const LiveRegSet *LiveRegsCopy) {
305   GCNRPTracker::reset(MI, LiveRegsCopy, true);
306 }
307 
recede(const MachineInstr & MI)308 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
309   assert(MRI && "call reset first");
310 
311   LastTrackedMI = &MI;
312 
313   if (MI.isDebugInstr())
314     return;
315 
316   auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
317 
318   // calc pressure at the MI (defs + uses)
319   auto AtMIPressure = CurPressure;
320   for (const auto &U : RegUses) {
321     auto LiveMask = LiveRegs[U.RegUnit];
322     AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
323   }
324   // update max pressure
325   MaxPressure = max(AtMIPressure, MaxPressure);
326 
327   for (const auto &MO : MI.defs()) {
328     if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()) ||
329          MO.isDead())
330       continue;
331 
332     auto Reg = MO.getReg();
333     auto I = LiveRegs.find(Reg);
334     if (I == LiveRegs.end())
335       continue;
336     auto &LiveMask = I->second;
337     auto PrevMask = LiveMask;
338     LiveMask &= ~getDefRegMask(MO, *MRI);
339     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
340     if (LiveMask.none())
341       LiveRegs.erase(I);
342   }
343   for (const auto &U : RegUses) {
344     auto &LiveMask = LiveRegs[U.RegUnit];
345     auto PrevMask = LiveMask;
346     LiveMask |= U.LaneMask;
347     CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
348   }
349   assert(CurPressure == getRegPressure(*MRI, LiveRegs));
350 }
351 
reset(const MachineInstr & MI,const LiveRegSet * LiveRegsCopy)352 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
353                                  const LiveRegSet *LiveRegsCopy) {
354   MRI = &MI.getParent()->getParent()->getRegInfo();
355   LastTrackedMI = nullptr;
356   MBBEnd = MI.getParent()->end();
357   NextMI = &MI;
358   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
359   if (NextMI == MBBEnd)
360     return false;
361   GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
362   return true;
363 }
364 
advanceBeforeNext()365 bool GCNDownwardRPTracker::advanceBeforeNext() {
366   assert(MRI && "call reset first");
367 
368   NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
369   if (NextMI == MBBEnd)
370     return false;
371 
372   SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
373   assert(SI.isValid());
374 
375   // Remove dead registers or mask bits.
376   for (auto &It : LiveRegs) {
377     const LiveInterval &LI = LIS.getInterval(It.first);
378     if (LI.hasSubRanges()) {
379       for (const auto &S : LI.subranges()) {
380         if (!S.liveAt(SI)) {
381           auto PrevMask = It.second;
382           It.second &= ~S.LaneMask;
383           CurPressure.inc(It.first, PrevMask, It.second, *MRI);
384         }
385       }
386     } else if (!LI.liveAt(SI)) {
387       auto PrevMask = It.second;
388       It.second = LaneBitmask::getNone();
389       CurPressure.inc(It.first, PrevMask, It.second, *MRI);
390     }
391     if (It.second.none())
392       LiveRegs.erase(It.first);
393   }
394 
395   MaxPressure = max(MaxPressure, CurPressure);
396 
397   return true;
398 }
399 
advanceToNext()400 void GCNDownwardRPTracker::advanceToNext() {
401   LastTrackedMI = &*NextMI++;
402 
403   // Add new registers or mask bits.
404   for (const auto &MO : LastTrackedMI->defs()) {
405     if (!MO.isReg())
406       continue;
407     unsigned Reg = MO.getReg();
408     if (!TargetRegisterInfo::isVirtualRegister(Reg))
409       continue;
410     auto &LiveMask = LiveRegs[Reg];
411     auto PrevMask = LiveMask;
412     LiveMask |= getDefRegMask(MO, *MRI);
413     CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
414   }
415 
416   MaxPressure = max(MaxPressure, CurPressure);
417 }
418 
advance()419 bool GCNDownwardRPTracker::advance() {
420   // If we have just called reset live set is actual.
421   if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
422     return false;
423   advanceToNext();
424   return true;
425 }
426 
advance(MachineBasicBlock::const_iterator End)427 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
428   while (NextMI != End)
429     if (!advance()) return false;
430   return true;
431 }
432 
advance(MachineBasicBlock::const_iterator Begin,MachineBasicBlock::const_iterator End,const LiveRegSet * LiveRegsCopy)433 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
434                                    MachineBasicBlock::const_iterator End,
435                                    const LiveRegSet *LiveRegsCopy) {
436   reset(*Begin, LiveRegsCopy);
437   return advance(End);
438 }
439 
440 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
441 LLVM_DUMP_METHOD
reportMismatch(const GCNRPTracker::LiveRegSet & LISLR,const GCNRPTracker::LiveRegSet & TrackedLR,const TargetRegisterInfo * TRI)442 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
443                            const GCNRPTracker::LiveRegSet &TrackedLR,
444                            const TargetRegisterInfo *TRI) {
445   for (auto const &P : TrackedLR) {
446     auto I = LISLR.find(P.first);
447     if (I == LISLR.end()) {
448       dbgs() << "  " << printReg(P.first, TRI)
449              << ":L" << PrintLaneMask(P.second)
450              << " isn't found in LIS reported set\n";
451     }
452     else if (I->second != P.second) {
453       dbgs() << "  " << printReg(P.first, TRI)
454         << " masks doesn't match: LIS reported "
455         << PrintLaneMask(I->second)
456         << ", tracked "
457         << PrintLaneMask(P.second)
458         << '\n';
459     }
460   }
461   for (auto const &P : LISLR) {
462     auto I = TrackedLR.find(P.first);
463     if (I == TrackedLR.end()) {
464       dbgs() << "  " << printReg(P.first, TRI)
465              << ":L" << PrintLaneMask(P.second)
466              << " isn't found in tracked set\n";
467     }
468   }
469 }
470 
isValid() const471 bool GCNUpwardRPTracker::isValid() const {
472   const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
473   const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
474   const auto &TrackedLR = LiveRegs;
475 
476   if (!isEqual(LISLR, TrackedLR)) {
477     dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
478               " LIS reported livesets mismatch:\n";
479     printLivesAt(SI, LIS, *MRI);
480     reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
481     return false;
482   }
483 
484   auto LISPressure = getRegPressure(*MRI, LISLR);
485   if (LISPressure != CurPressure) {
486     dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
487     CurPressure.print(dbgs());
488     dbgs() << "LIS rpt: ";
489     LISPressure.print(dbgs());
490     return false;
491   }
492   return true;
493 }
494 
printLiveRegs(raw_ostream & OS,const LiveRegSet & LiveRegs,const MachineRegisterInfo & MRI)495 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
496                                  const MachineRegisterInfo &MRI) {
497   const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
498   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
499     unsigned Reg = TargetRegisterInfo::index2VirtReg(I);
500     auto It = LiveRegs.find(Reg);
501     if (It != LiveRegs.end() && It->second.any())
502       OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
503          << PrintLaneMask(It->second);
504   }
505   OS << '\n';
506 }
507 #endif
508