1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This file contains the AArch64 / Cortex-A57 specific register allocation
10 // constraints for use by the PBQP register allocator.
11 //
12 // It is essentially a transcription of what is contained in
13 // AArch64A57FPLoadBalancing, which tries to use a balanced
14 // mix of odd and even D-registers when performing a critical sequence of
15 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
16 //===----------------------------------------------------------------------===//
17
18 #define DEBUG_TYPE "aarch64-pbqp"
19
20 #include "AArch64.h"
21 #include "AArch64PBQPRegAlloc.h"
22 #include "AArch64RegisterInfo.h"
23 #include "llvm/CodeGen/LiveIntervalAnalysis.h"
24 #include "llvm/CodeGen/MachineBasicBlock.h"
25 #include "llvm/CodeGen/MachineFunction.h"
26 #include "llvm/CodeGen/MachineRegisterInfo.h"
27 #include "llvm/CodeGen/RegAllocPBQP.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "llvm/Support/raw_ostream.h"
31
32 using namespace llvm;
33
34 namespace {
35
36 #ifndef NDEBUG
isFPReg(unsigned reg)37 bool isFPReg(unsigned reg) {
38 return AArch64::FPR32RegClass.contains(reg) ||
39 AArch64::FPR64RegClass.contains(reg) ||
40 AArch64::FPR128RegClass.contains(reg);
41 }
42 #endif
43
isOdd(unsigned reg)44 bool isOdd(unsigned reg) {
45 switch (reg) {
46 default:
47 llvm_unreachable("Register is not from the expected class !");
48 case AArch64::S1:
49 case AArch64::S3:
50 case AArch64::S5:
51 case AArch64::S7:
52 case AArch64::S9:
53 case AArch64::S11:
54 case AArch64::S13:
55 case AArch64::S15:
56 case AArch64::S17:
57 case AArch64::S19:
58 case AArch64::S21:
59 case AArch64::S23:
60 case AArch64::S25:
61 case AArch64::S27:
62 case AArch64::S29:
63 case AArch64::S31:
64 case AArch64::D1:
65 case AArch64::D3:
66 case AArch64::D5:
67 case AArch64::D7:
68 case AArch64::D9:
69 case AArch64::D11:
70 case AArch64::D13:
71 case AArch64::D15:
72 case AArch64::D17:
73 case AArch64::D19:
74 case AArch64::D21:
75 case AArch64::D23:
76 case AArch64::D25:
77 case AArch64::D27:
78 case AArch64::D29:
79 case AArch64::D31:
80 case AArch64::Q1:
81 case AArch64::Q3:
82 case AArch64::Q5:
83 case AArch64::Q7:
84 case AArch64::Q9:
85 case AArch64::Q11:
86 case AArch64::Q13:
87 case AArch64::Q15:
88 case AArch64::Q17:
89 case AArch64::Q19:
90 case AArch64::Q21:
91 case AArch64::Q23:
92 case AArch64::Q25:
93 case AArch64::Q27:
94 case AArch64::Q29:
95 case AArch64::Q31:
96 return true;
97 case AArch64::S0:
98 case AArch64::S2:
99 case AArch64::S4:
100 case AArch64::S6:
101 case AArch64::S8:
102 case AArch64::S10:
103 case AArch64::S12:
104 case AArch64::S14:
105 case AArch64::S16:
106 case AArch64::S18:
107 case AArch64::S20:
108 case AArch64::S22:
109 case AArch64::S24:
110 case AArch64::S26:
111 case AArch64::S28:
112 case AArch64::S30:
113 case AArch64::D0:
114 case AArch64::D2:
115 case AArch64::D4:
116 case AArch64::D6:
117 case AArch64::D8:
118 case AArch64::D10:
119 case AArch64::D12:
120 case AArch64::D14:
121 case AArch64::D16:
122 case AArch64::D18:
123 case AArch64::D20:
124 case AArch64::D22:
125 case AArch64::D24:
126 case AArch64::D26:
127 case AArch64::D28:
128 case AArch64::D30:
129 case AArch64::Q0:
130 case AArch64::Q2:
131 case AArch64::Q4:
132 case AArch64::Q6:
133 case AArch64::Q8:
134 case AArch64::Q10:
135 case AArch64::Q12:
136 case AArch64::Q14:
137 case AArch64::Q16:
138 case AArch64::Q18:
139 case AArch64::Q20:
140 case AArch64::Q22:
141 case AArch64::Q24:
142 case AArch64::Q26:
143 case AArch64::Q28:
144 case AArch64::Q30:
145 return false;
146
147 }
148 }
149
haveSameParity(unsigned reg1,unsigned reg2)150 bool haveSameParity(unsigned reg1, unsigned reg2) {
151 assert(isFPReg(reg1) && "Expecting an FP register for reg1");
152 assert(isFPReg(reg2) && "Expecting an FP register for reg2");
153
154 return isOdd(reg1) == isOdd(reg2);
155 }
156
157 }
158
addIntraChainConstraint(PBQPRAGraph & G,unsigned Rd,unsigned Ra)159 bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
160 unsigned Ra) {
161 if (Rd == Ra)
162 return false;
163
164 LiveIntervals &LIs = G.getMetadata().LIS;
165
166 if (TRI->isPhysicalRegister(Rd) || TRI->isPhysicalRegister(Ra)) {
167 DEBUG(dbgs() << "Rd is a physical reg:" << TRI->isPhysicalRegister(Rd)
168 << '\n');
169 DEBUG(dbgs() << "Ra is a physical reg:" << TRI->isPhysicalRegister(Ra)
170 << '\n');
171 return false;
172 }
173
174 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
175 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
176
177 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
178 &G.getNodeMetadata(node1).getAllowedRegs();
179 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
180 &G.getNodeMetadata(node2).getAllowedRegs();
181
182 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
183
184 // The edge does not exist. Create one with the appropriate interference
185 // costs.
186 if (edge == G.invalidEdgeId()) {
187 const LiveInterval &ld = LIs.getInterval(Rd);
188 const LiveInterval &la = LIs.getInterval(Ra);
189 bool livesOverlap = ld.overlaps(la);
190
191 PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
192 vRaAllowed->size() + 1, 0);
193 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
194 unsigned pRd = (*vRdAllowed)[i];
195 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
196 unsigned pRa = (*vRaAllowed)[j];
197 if (livesOverlap && TRI->regsOverlap(pRd, pRa))
198 costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
199 else
200 costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
201 }
202 }
203 G.addEdge(node1, node2, std::move(costs));
204 return true;
205 }
206
207 if (G.getEdgeNode1Id(edge) == node2) {
208 std::swap(node1, node2);
209 std::swap(vRdAllowed, vRaAllowed);
210 }
211
212 // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
213 PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
214 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
215 unsigned pRd = (*vRdAllowed)[i];
216
217 // Get the maximum cost (excluding unallocatable reg) for same parity
218 // registers
219 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
220 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
221 unsigned pRa = (*vRaAllowed)[j];
222 if (haveSameParity(pRd, pRa))
223 if (costs[i + 1][j + 1] !=
224 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
225 costs[i + 1][j + 1] > sameParityMax)
226 sameParityMax = costs[i + 1][j + 1];
227 }
228
229 // Ensure all registers with a different parity have a higher cost
230 // than sameParityMax
231 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
232 unsigned pRa = (*vRaAllowed)[j];
233 if (!haveSameParity(pRd, pRa))
234 if (sameParityMax > costs[i + 1][j + 1])
235 costs[i + 1][j + 1] = sameParityMax + 1.0;
236 }
237 }
238 G.updateEdgeCosts(edge, std::move(costs));
239
240 return true;
241 }
242
addInterChainConstraint(PBQPRAGraph & G,unsigned Rd,unsigned Ra)243 void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
244 unsigned Ra) {
245 LiveIntervals &LIs = G.getMetadata().LIS;
246
247 // Do some Chain management
248 if (Chains.count(Ra)) {
249 if (Rd != Ra) {
250 DEBUG(dbgs() << "Moving acc chain from " << PrintReg(Ra, TRI) << " to "
251 << PrintReg(Rd, TRI) << '\n';);
252 Chains.remove(Ra);
253 Chains.insert(Rd);
254 }
255 } else {
256 DEBUG(dbgs() << "Creating new acc chain for " << PrintReg(Rd, TRI)
257 << '\n';);
258 Chains.insert(Rd);
259 }
260
261 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
262
263 const LiveInterval &ld = LIs.getInterval(Rd);
264 for (auto r : Chains) {
265 // Skip self
266 if (r == Rd)
267 continue;
268
269 const LiveInterval &lr = LIs.getInterval(r);
270 if (ld.overlaps(lr)) {
271 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
272 &G.getNodeMetadata(node1).getAllowedRegs();
273
274 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
275 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
276 &G.getNodeMetadata(node2).getAllowedRegs();
277
278 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
279 assert(edge != G.invalidEdgeId() &&
280 "PBQP error ! The edge should exist !");
281
282 DEBUG(dbgs() << "Refining constraint !\n";);
283
284 if (G.getEdgeNode1Id(edge) == node2) {
285 std::swap(node1, node2);
286 std::swap(vRdAllowed, vRrAllowed);
287 }
288
289 // Enforce that cost is higher with all other Chains of the same parity
290 PBQP::Matrix costs(G.getEdgeCosts(edge));
291 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
292 unsigned pRd = (*vRdAllowed)[i];
293
294 // Get the maximum cost (excluding unallocatable reg) for all other
295 // parity registers
296 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
297 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
298 unsigned pRa = (*vRrAllowed)[j];
299 if (!haveSameParity(pRd, pRa))
300 if (costs[i + 1][j + 1] !=
301 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
302 costs[i + 1][j + 1] > sameParityMax)
303 sameParityMax = costs[i + 1][j + 1];
304 }
305
306 // Ensure all registers with same parity have a higher cost
307 // than sameParityMax
308 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
309 unsigned pRa = (*vRrAllowed)[j];
310 if (haveSameParity(pRd, pRa))
311 if (sameParityMax > costs[i + 1][j + 1])
312 costs[i + 1][j + 1] = sameParityMax + 1.0;
313 }
314 }
315 G.updateEdgeCosts(edge, std::move(costs));
316 }
317 }
318 }
319
regJustKilledBefore(const LiveIntervals & LIs,unsigned reg,const MachineInstr & MI)320 static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
321 const MachineInstr &MI) {
322 const LiveInterval &LI = LIs.getInterval(reg);
323 SlotIndex SI = LIs.getInstructionIndex(&MI);
324 return LI.expiredAt(SI);
325 }
326
apply(PBQPRAGraph & G)327 void A57ChainingConstraint::apply(PBQPRAGraph &G) {
328 const MachineFunction &MF = G.getMetadata().MF;
329 LiveIntervals &LIs = G.getMetadata().LIS;
330
331 TRI = MF.getSubtarget().getRegisterInfo();
332 DEBUG(MF.dump());
333
334 for (const auto &MBB: MF) {
335 Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
336
337 for (const auto &MI: MBB) {
338
339 // Forget Chains which have expired
340 for (auto r : Chains) {
341 SmallVector<unsigned, 8> toDel;
342 if(regJustKilledBefore(LIs, r, MI)) {
343 DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at ";
344 MI.print(dbgs()););
345 toDel.push_back(r);
346 }
347
348 while (!toDel.empty()) {
349 Chains.remove(toDel.back());
350 toDel.pop_back();
351 }
352 }
353
354 switch (MI.getOpcode()) {
355 case AArch64::FMSUBSrrr:
356 case AArch64::FMADDSrrr:
357 case AArch64::FNMSUBSrrr:
358 case AArch64::FNMADDSrrr:
359 case AArch64::FMSUBDrrr:
360 case AArch64::FMADDDrrr:
361 case AArch64::FNMSUBDrrr:
362 case AArch64::FNMADDDrrr: {
363 unsigned Rd = MI.getOperand(0).getReg();
364 unsigned Ra = MI.getOperand(3).getReg();
365
366 if (addIntraChainConstraint(G, Rd, Ra))
367 addInterChainConstraint(G, Rd, Ra);
368 break;
369 }
370
371 case AArch64::FMLAv2f32:
372 case AArch64::FMLSv2f32: {
373 unsigned Rd = MI.getOperand(0).getReg();
374 addInterChainConstraint(G, Rd, Rd);
375 break;
376 }
377
378 default:
379 break;
380 }
381 }
382 }
383 }
384