//===- SPIRVLowerBool.cpp - Lower instructions with bool operands ---------===//
2 //
3 //                     The LLVM/SPIRV Translator
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 // Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
9 //
10 // Permission is hereby granted, free of charge, to any person obtaining a
11 // copy of this software and associated documentation files (the "Software"),
12 // to deal with the Software without restriction, including without limitation
13 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
14 // and/or sell copies of the Software, and to permit persons to whom the
15 // Software is furnished to do so, subject to the following conditions:
16 //
17 // Redistributions of source code must retain the above copyright notice,
18 // this list of conditions and the following disclaimers.
19 // Redistributions in binary form must reproduce the above copyright notice,
20 // this list of conditions and the following disclaimers in the documentation
21 // and/or other materials provided with the distribution.
22 // Neither the names of Advanced Micro Devices, Inc., nor the names of its
23 // contributors may be used to endorse or promote products derived from this
24 // Software without specific prior written permission.
25 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28 // CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
31 // THE SOFTWARE.
32 //
33 //===----------------------------------------------------------------------===//
34 //
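// This file implements a module pass that lowers instructions converting to or
// from bool (i1) types (trunc to i1, zext/sext from i1) into equivalent
// compare/select-based instruction sequences.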
36 //
37 //===----------------------------------------------------------------------===//
38 #define DEBUG_TYPE "spvbool"
39 
40 #include "SPIRVInternal.h"
41 #include "llvm/IR/InstVisitor.h"
42 #include "llvm/IR/Instructions.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/Verifier.h"
45 #include "llvm/Pass.h"
46 #include "llvm/PassSupport.h"
47 #include "llvm/Support/CommandLine.h"
48 #include "llvm/Support/Debug.h"
49 #include "llvm/Support/raw_ostream.h"
50 
51 using namespace llvm;
52 using namespace SPIRV;
53 
54 namespace SPIRV {
55 cl::opt<bool> SPIRVLowerBoolValidate("spvbool-validate",
56     cl::desc("Validate module after lowering boolean instructions for SPIR-V"));
57 
58 class SPIRVLowerBool: public ModulePass,
59   public InstVisitor<SPIRVLowerBool> {
60 public:
SPIRVLowerBool()61   SPIRVLowerBool():ModulePass(ID), Context(nullptr) {
62     initializeSPIRVLowerBoolPass(*PassRegistry::getPassRegistry());
63   }
replace(Instruction * I,Instruction * NewI)64   void replace(Instruction *I, Instruction *NewI) {
65     NewI->takeName(I);
66     I->replaceAllUsesWith(NewI);
67     I->dropAllReferences();
68     I->eraseFromParent();
69   }
isBoolType(Type * Ty)70   bool isBoolType(Type *Ty) {
71     if (Ty->isIntegerTy(1))
72       return true;
73     if (auto VT = dyn_cast<VectorType>(Ty))
74       return isBoolType(VT->getElementType());
75     return false;
76   }
visitTruncInst(TruncInst & I)77   virtual void visitTruncInst(TruncInst &I) {
78     if (isBoolType(I.getType())) {
79       auto Op = I.getOperand(0);
80       auto Zero = getScalarOrVectorConstantInt(Op->getType(), 0, false);
81       auto Cmp = new ICmpInst(&I, CmpInst::ICMP_NE, Op, Zero);
82       replace(&I, Cmp);
83     }
84   }
visitZExtInst(ZExtInst & I)85   virtual void visitZExtInst(ZExtInst &I) {
86     auto Op = I.getOperand(0);
87     if (isBoolType(Op->getType())) {
88       auto Ty = I.getType();
89       auto Zero = getScalarOrVectorConstantInt(Ty, 0, false);
90       auto One = getScalarOrVectorConstantInt(Ty, 1, false);
91       auto Sel = SelectInst::Create(Op, One, Zero, "", &I);
92       replace(&I, Sel);
93     }
94   }
visitSExtInst(SExtInst & I)95   virtual void visitSExtInst(SExtInst &I) {
96     auto Op = I.getOperand(0);
97     if (isBoolType(Op->getType())) {
98       auto Ty = I.getType();
99       auto Zero = getScalarOrVectorConstantInt(Ty, 0, false);
100       auto One = getScalarOrVectorConstantInt(Ty, ~0, false);
101       auto Sel = SelectInst::Create(Op, One, Zero, "", &I);
102       replace(&I, Sel);
103     }
104   }
runOnModule(Module & M)105   virtual bool runOnModule(Module &M) {
106     Context = &M.getContext();
107     visit(M);
108 
109     if (SPIRVLowerBoolValidate) {
110       DEBUG(dbgs() << "After SPIRVLowerBool:\n" << M);
111       std::string Err;
112       raw_string_ostream ErrorOS(Err);
113       if (verifyModule(M, &ErrorOS)){
114         Err = std::string("Fails to verify module: ") + Err;
115         report_fatal_error(Err.c_str(), false);
116       }
117     }
118     return true;
119   }
120 
121   static char ID;
122 private:
123   LLVMContext *Context;
124 };
125 
126 char SPIRVLowerBool::ID = 0;
127 }
128 
129 INITIALIZE_PASS(SPIRVLowerBool, "spvbool",
130     "Lower instructions with bool operands", false, false)
131 
createSPIRVLowerBool()132 ModulePass *llvm::createSPIRVLowerBool() {
133   return new SPIRVLowerBool();
134 }
135