1 //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
/// This pass performs the following type substitutions on all
/// non-compute shaders:
///
/// v16i8 => v4i32
///   - v16i8 is used for constant memory resource descriptors.  This type is
///     legal for some compute APIs, and we don't want to declare it as legal
///     in the backend, because we want the legalizer to expand all v16i8
///     operations.
/// v1i32 => i32
///   - Having v1* types complicates the legalizer and we can easily replace
///     them with the element type.  Currently only <1 x i32> call arguments
///     are rewritten.
22 //===----------------------------------------------------------------------===//
23 
24 #include "AMDGPU.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/InstVisitor.h"
27 
28 using namespace llvm;
29 
30 namespace {
31 
32 class SITypeRewriter : public FunctionPass,
33                        public InstVisitor<SITypeRewriter> {
34 
35   static char ID;
36   Module *Mod;
37   Type *v16i8;
38   Type *v4i32;
39 
40 public:
SITypeRewriter()41   SITypeRewriter() : FunctionPass(ID) { }
42   bool doInitialization(Module &M) override;
43   bool runOnFunction(Function &F) override;
getPassName() const44   const char *getPassName() const override {
45     return "SI Type Rewriter";
46   }
47   void visitLoadInst(LoadInst &I);
48   void visitCallInst(CallInst &I);
49   void visitBitCast(BitCastInst &I);
50 };
51 
52 } // End anonymous namespace
53 
54 char SITypeRewriter::ID = 0;
55 
doInitialization(Module & M)56 bool SITypeRewriter::doInitialization(Module &M) {
57   Mod = &M;
58   v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
59   v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4);
60   return false;
61 }
62 
runOnFunction(Function & F)63 bool SITypeRewriter::runOnFunction(Function &F) {
64   Attribute A = F.getFnAttribute("ShaderType");
65 
66   unsigned ShaderType = ShaderType::COMPUTE;
67   if (A.isStringAttribute()) {
68     StringRef Str = A.getValueAsString();
69     Str.getAsInteger(0, ShaderType);
70   }
71   if (ShaderType == ShaderType::COMPUTE)
72     return false;
73 
74   visit(F);
75   visit(F);
76 
77   return false;
78 }
79 
visitLoadInst(LoadInst & I)80 void SITypeRewriter::visitLoadInst(LoadInst &I) {
81   Value *Ptr = I.getPointerOperand();
82   Type *PtrTy = Ptr->getType();
83   Type *ElemTy = PtrTy->getPointerElementType();
84   IRBuilder<> Builder(&I);
85   if (ElemTy == v16i8)  {
86     Value *BitCast = Builder.CreateBitCast(Ptr,
87         PointerType::get(v4i32,PtrTy->getPointerAddressSpace()));
88     LoadInst *Load = Builder.CreateLoad(BitCast);
89     SmallVector<std::pair<unsigned, MDNode *>, 8> MD;
90     I.getAllMetadataOtherThanDebugLoc(MD);
91     for (unsigned i = 0, e = MD.size(); i != e; ++i) {
92       Load->setMetadata(MD[i].first, MD[i].second);
93     }
94     Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
95     I.replaceAllUsesWith(BitCastLoad);
96     I.eraseFromParent();
97   }
98 }
99 
visitCallInst(CallInst & I)100 void SITypeRewriter::visitCallInst(CallInst &I) {
101   IRBuilder<> Builder(&I);
102 
103   SmallVector <Value*, 8> Args;
104   SmallVector <Type*, 8> Types;
105   bool NeedToReplace = false;
106   Function *F = I.getCalledFunction();
107   std::string Name = F->getName();
108   for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
109     Value *Arg = I.getArgOperand(i);
110     if (Arg->getType() == v16i8) {
111       Args.push_back(Builder.CreateBitCast(Arg, v4i32));
112       Types.push_back(v4i32);
113       NeedToReplace = true;
114       Name = Name + ".v4i32";
115     } else if (Arg->getType()->isVectorTy() &&
116                Arg->getType()->getVectorNumElements() == 1 &&
117                Arg->getType()->getVectorElementType() ==
118                                               Type::getInt32Ty(I.getContext())){
119       Type *ElementTy = Arg->getType()->getVectorElementType();
120       std::string TypeName = "i32";
121       InsertElementInst *Def = cast<InsertElementInst>(Arg);
122       Args.push_back(Def->getOperand(1));
123       Types.push_back(ElementTy);
124       std::string VecTypeName = "v1" + TypeName;
125       Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
126       NeedToReplace = true;
127     } else {
128       Args.push_back(Arg);
129       Types.push_back(Arg->getType());
130     }
131   }
132 
133   if (!NeedToReplace) {
134     return;
135   }
136   Function *NewF = Mod->getFunction(Name);
137   if (!NewF) {
138     NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
139     NewF->setAttributes(F->getAttributes());
140   }
141   I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
142   I.eraseFromParent();
143 }
144 
visitBitCast(BitCastInst & I)145 void SITypeRewriter::visitBitCast(BitCastInst &I) {
146   IRBuilder<> Builder(&I);
147   if (I.getDestTy() != v4i32) {
148     return;
149   }
150 
151   if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
152     if (Op->getSrcTy() == v4i32) {
153       I.replaceAllUsesWith(Op->getOperand(0));
154       I.eraseFromParent();
155     }
156   }
157 }
158 
createSITypeRewriter()159 FunctionPass *llvm::createSITypeRewriter() {
160   return new SITypeRewriter();
161 }
162