1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
18 
19 #include "llvm/IR/IRBuilder.h"
20 
21 namespace xla {
22 
23 // Mixin class that injects more ergonomic versions of llvm::IRBuilder methods
24 // into a class.  Intended to be used as a CRTP base class, like:
25 //
26 //  class MyIrEmitter : public IrBuilderMixin<MyIrEmitter> {
27 //    llvm::IRBuilder<>* builder() { return builder_; }
28 //
29 //    void EmitFoo(HloInstruction* foo) {
30 //      Add(Mul(...), FPToUI(...));
31 //    }
32 //  };
33 
34 template <typename Derived>
35 class IrBuilderMixin {
36  protected:
37   template <class... Args>
Add(Args &&...args)38   llvm::Value* Add(Args&&... args) {
39     return mixin_builder()->CreateAdd(std::forward<Args>(args)...);
40   }
41 
42   template <class... Args>
AlignedLoad(Args &&...args)43   llvm::LoadInst* AlignedLoad(Args&&... args) {
44     return mixin_builder()->CreateAlignedLoad(std::forward<Args>(args)...);
45   }
46 
47   template <class... Args>
AlignedStore(Args &&...args)48   llvm::StoreInst* AlignedStore(Args&&... args) {
49     return mixin_builder()->CreateAlignedStore(std::forward<Args>(args)...);
50   }
51 
52   template <class... Args>
Alloca(Args &&...args)53   llvm::AllocaInst* Alloca(Args&&... args) {
54     return mixin_builder()->CreateAlloca(std::forward<Args>(args)...);
55   }
56 
57   template <class... Args>
And(Args &&...args)58   llvm::Value* And(Args&&... args) {
59     return mixin_builder()->CreateAnd(std::forward<Args>(args)...);
60   }
61 
62   template <class... Args>
AtomicCmpXchg(Args &&...args)63   llvm::Value* AtomicCmpXchg(Args&&... args) {
64     return mixin_builder()->CreateAtomicCmpXchg(std::forward<Args>(args)...);
65   }
66 
67   template <class... Args>
AtomicRMW(Args &&...args)68   llvm::Value* AtomicRMW(Args&&... args) {
69     return mixin_builder()->CreateAtomicRMW(std::forward<Args>(args)...);
70   }
71 
72   template <class... Args>
BitCast(Args &&...args)73   llvm::Value* BitCast(Args&&... args) {
74     return mixin_builder()->CreateBitCast(std::forward<Args>(args)...);
75   }
76 
77   template <class... Args>
Br(Args &&...args)78   llvm::Value* Br(Args&&... args) {
79     return mixin_builder()->CreateBr(std::forward<Args>(args)...);
80   }
81 
82   llvm::CallInst* Call(llvm::Value* callee,
83                        llvm::ArrayRef<llvm::Value*> args = llvm::None,
84                        const llvm::Twine& name = "",
85                        llvm::MDNode* fp_math_tag = nullptr) {
86     return mixin_builder()->CreateCall(callee, args, name, fp_math_tag);
87   }
88 
89   template <class... Args>
CondBr(Args &&...args)90   llvm::BranchInst* CondBr(Args&&... args) {
91     return mixin_builder()->CreateCondBr(std::forward<Args>(args)...);
92   }
93 
94   template <class... Args>
ConstInBoundsGEP1_32(Args &&...args)95   llvm::Value* ConstInBoundsGEP1_32(Args&&... args) {
96     return mixin_builder()->CreateConstInBoundsGEP1_32(
97         std::forward<Args>(args)...);
98   }
99 
100   template <class... Args>
FAdd(Args &&...args)101   llvm::Value* FAdd(Args&&... args) {
102     return mixin_builder()->CreateFAdd(std::forward<Args>(args)...);
103   }
104 
105   template <class... Args>
FMul(Args &&...args)106   llvm::Value* FMul(Args&&... args) {
107     return mixin_builder()->CreateFMul(std::forward<Args>(args)...);
108   }
109 
110   llvm::Value* GEP(llvm::Value* ptr, llvm::ArrayRef<llvm::Value*> idx_list,
111                    const llvm::Twine& name = "") {
112     return mixin_builder()->CreateGEP(ptr, idx_list, name);
113   }
114 
115   template <class... Args>
ICmpEQ(Args &&...args)116   llvm::Value* ICmpEQ(Args&&... args) {
117     return mixin_builder()->CreateICmpEQ(std::forward<Args>(args)...);
118   }
119 
120   template <class... Args>
ICmpNE(Args &&...args)121   llvm::Value* ICmpNE(Args&&... args) {
122     return mixin_builder()->CreateICmpNE(std::forward<Args>(args)...);
123   }
124 
125   template <class... Args>
ICmpULE(Args &&...args)126   llvm::Value* ICmpULE(Args&&... args) {
127     return mixin_builder()->CreateICmpULE(std::forward<Args>(args)...);
128   }
129 
130   template <class... Args>
ICmpULT(Args &&...args)131   llvm::Value* ICmpULT(Args&&... args) {
132     return mixin_builder()->CreateICmpULT(std::forward<Args>(args)...);
133   }
134 
135   llvm::Value* InBoundsGEP(llvm::Value* ptr,
136                            llvm::ArrayRef<llvm::Value*> idx_list,
137                            const llvm::Twine& name = "") {
138     return mixin_builder()->CreateInBoundsGEP(ptr, idx_list, name);
139   }
140 
141   llvm::Value* ExtractValue(llvm::Value* agg, llvm::ArrayRef<unsigned> idxs,
142                             const llvm::Twine& name = "") {
143     return mixin_builder()->CreateExtractValue(agg, idxs, name);
144   }
145 
146   llvm::Value* InsertValue(llvm::Value* agg, llvm::Value* val,
147                            llvm::ArrayRef<unsigned> idxs,
148                            const llvm::Twine& name = "") {
149     return mixin_builder()->CreateInsertValue(agg, val, idxs, name);
150   }
151 
152   template <class... Args>
IntToPtr(Args &&...args)153   llvm::Value* IntToPtr(Args&&... args) {
154     return mixin_builder()->CreateIntToPtr(std::forward<Args>(args)...);
155   }
156 
157   template <class... Args>
Load(Args &&...args)158   llvm::LoadInst* Load(Args&&... args) {
159     return mixin_builder()->CreateLoad(std::forward<Args>(args)...);
160   }
161 
162   template <class... Args>
MemCpy(Args &&...args)163   llvm::CallInst* MemCpy(Args&&... args) {
164     return mixin_builder()->CreateMemCpy(std::forward<Args>(args)...);
165   }
166 
167   template <class... Args>
Mul(Args &&...args)168   llvm::Value* Mul(Args&&... args) {
169     return mixin_builder()->CreateMul(std::forward<Args>(args)...);
170   }
171 
172   template <class... Args>
NSWAdd(Args &&...args)173   llvm::Value* NSWAdd(Args&&... args) {
174     return mixin_builder()->CreateNSWAdd(std::forward<Args>(args)...);
175   }
176 
177   template <class... Args>
NSWMul(Args &&...args)178   llvm::Value* NSWMul(Args&&... args) {
179     return mixin_builder()->CreateNSWMul(std::forward<Args>(args)...);
180   }
181 
182   template <class... Args>
NSWSub(Args &&...args)183   llvm::Value* NSWSub(Args&&... args) {
184     return mixin_builder()->CreateNSWSub(std::forward<Args>(args)...);
185   }
186 
187   template <class... Args>
Or(Args &&...args)188   llvm::Value* Or(Args&&... args) {
189     return mixin_builder()->CreateOr(std::forward<Args>(args)...);
190   }
191 
192   template <class... Args>
PointerCast(Args &&...args)193   llvm::Value* PointerCast(Args&&... args) {
194     return mixin_builder()->CreatePointerCast(std::forward<Args>(args)...);
195   }
196 
197   template <class... Args>
PtrToInt(Args &&...args)198   llvm::Value* PtrToInt(Args&&... args) {
199     return mixin_builder()->CreatePtrToInt(std::forward<Args>(args)...);
200   }
201 
202   template <class... Args>
SDiv(Args &&...args)203   llvm::Value* SDiv(Args&&... args) {
204     return mixin_builder()->CreateSDiv(std::forward<Args>(args)...);
205   }
206 
207   template <class... Args>
Select(Args &&...args)208   llvm::Value* Select(Args&&... args) {
209     return mixin_builder()->CreateSelect(std::forward<Args>(args)...);
210   }
211 
212   template <class... Args>
SRem(Args &&...args)213   llvm::Value* SRem(Args&&... args) {
214     return mixin_builder()->CreateSRem(std::forward<Args>(args)...);
215   }
216 
217   template <class... Args>
Store(Args &&...args)218   llvm::StoreInst* Store(Args&&... args) {
219     return mixin_builder()->CreateStore(std::forward<Args>(args)...);
220   }
221 
222   template <class... Args>
UDiv(Args &&...args)223   llvm::Value* UDiv(Args&&... args) {
224     return mixin_builder()->CreateUDiv(std::forward<Args>(args)...);
225   }
226 
227   template <class... Args>
URem(Args &&...args)228   llvm::Value* URem(Args&&... args) {
229     return mixin_builder()->CreateURem(std::forward<Args>(args)...);
230   }
231 
232   template <class... Args>
VectorSplat(Args &&...args)233   llvm::Value* VectorSplat(Args&&... args) {
234     return mixin_builder()->CreateVectorSplat(std::forward<Args>(args)...);
235   }
236 
237   template <class... Args>
ZExtOrTrunc(Args &&...args)238   llvm::Value* ZExtOrTrunc(Args&&... args) {
239     return mixin_builder()->CreateZExtOrTrunc(std::forward<Args>(args)...);
240   }
241 
242   template <class... Args>
AShr(Args &&...args)243   llvm::Value* AShr(Args&&... args) {
244     return mixin_builder()->CreateAShr(std::forward<Args>(args)...);
245   }
246 
247   template <class... Args>
FCmpOEQ(Args &&...args)248   llvm::Value* FCmpOEQ(Args&&... args) {
249     return mixin_builder()->CreateFCmpOEQ(std::forward<Args>(args)...);
250   }
251 
252   template <class... Args>
FCmpOLT(Args &&...args)253   llvm::Value* FCmpOLT(Args&&... args) {
254     return mixin_builder()->CreateFCmpOLT(std::forward<Args>(args)...);
255   }
256 
257   template <class... Args>
FCmpOLE(Args &&...args)258   llvm::Value* FCmpOLE(Args&&... args) {
259     return mixin_builder()->CreateFCmpOLE(std::forward<Args>(args)...);
260   }
261 
262   template <class... Args>
FCmpONE(Args &&...args)263   llvm::Value* FCmpONE(Args&&... args) {
264     return mixin_builder()->CreateFCmpONE(std::forward<Args>(args)...);
265   }
266 
267   template <class... Args>
FCmpUNE(Args &&...args)268   llvm::Value* FCmpUNE(Args&&... args) {
269     return mixin_builder()->CreateFCmpUNE(std::forward<Args>(args)...);
270   }
271 
272   template <class... Args>
FCmpUNO(Args &&...args)273   llvm::Value* FCmpUNO(Args&&... args) {
274     return mixin_builder()->CreateFCmpUNO(std::forward<Args>(args)...);
275   }
276 
277   template <class... Args>
FDiv(Args &&...args)278   llvm::Value* FDiv(Args&&... args) {
279     return mixin_builder()->CreateFDiv(std::forward<Args>(args)...);
280   }
281 
282   template <class... Args>
FNeg(Args &&...args)283   llvm::Value* FNeg(Args&&... args) {
284     return mixin_builder()->CreateFNeg(std::forward<Args>(args)...);
285   }
286 
287   template <class... Args>
FPCast(Args &&...args)288   llvm::Value* FPCast(Args&&... args) {
289     return mixin_builder()->CreateFPCast(std::forward<Args>(args)...);
290   }
291 
292   template <class... Args>
FPToSI(Args &&...args)293   llvm::Value* FPToSI(Args&&... args) {
294     return mixin_builder()->CreateFPToSI(std::forward<Args>(args)...);
295   }
296 
297   template <class... Args>
FPToUI(Args &&...args)298   llvm::Value* FPToUI(Args&&... args) {
299     return mixin_builder()->CreateFPToUI(std::forward<Args>(args)...);
300   }
301 
302   template <class... Args>
FPTrunc(Args &&...args)303   llvm::Value* FPTrunc(Args&&... args) {
304     return mixin_builder()->CreateFPTrunc(std::forward<Args>(args)...);
305   }
306 
307   template <class... Args>
FRem(Args &&...args)308   llvm::Value* FRem(Args&&... args) {
309     return mixin_builder()->CreateFRem(std::forward<Args>(args)...);
310   }
311 
312   template <class... Args>
FSub(Args &&...args)313   llvm::Value* FSub(Args&&... args) {
314     return mixin_builder()->CreateFSub(std::forward<Args>(args)...);
315   }
316 
317   template <class... Args>
ICmpSGE(Args &&...args)318   llvm::Value* ICmpSGE(Args&&... args) {
319     return mixin_builder()->CreateICmpSGE(std::forward<Args>(args)...);
320   }
321 
322   template <class... Args>
ICmpSLT(Args &&...args)323   llvm::Value* ICmpSLT(Args&&... args) {
324     return mixin_builder()->CreateICmpSLT(std::forward<Args>(args)...);
325   }
326 
327   template <class... Args>
IntCast(Args &&...args)328   llvm::Value* IntCast(Args&&... args) {
329     return mixin_builder()->CreateIntCast(std::forward<Args>(args)...);
330   }
331 
332   template <class... Args>
LShr(Args &&...args)333   llvm::Value* LShr(Args&&... args) {
334     return mixin_builder()->CreateLShr(std::forward<Args>(args)...);
335   }
336 
337   template <class... Args>
MemSet(Args &&...args)338   llvm::Value* MemSet(Args&&... args) {
339     return mixin_builder()->CreateMemSet(std::forward<Args>(args)...);
340   }
341 
342   template <class... Args>
Neg(Args &&...args)343   llvm::Value* Neg(Args&&... args) {
344     return mixin_builder()->CreateNeg(std::forward<Args>(args)...);
345   }
346 
347   template <class... Args>
Not(Args &&...args)348   llvm::Value* Not(Args&&... args) {
349     return mixin_builder()->CreateNot(std::forward<Args>(args)...);
350   }
351 
352   template <class... Args>
PHI(Args &&...args)353   llvm::PHINode* PHI(Args&&... args) {
354     return mixin_builder()->CreatePHI(std::forward<Args>(args)...);
355   }
356 
357   template <class... Args>
RetVoid(Args &&...args)358   llvm::Value* RetVoid(Args&&... args) {
359     return mixin_builder()->CreateRetVoid(std::forward<Args>(args)...);
360   }
361 
362   template <class... Args>
SExtOrTrunc(Args &&...args)363   llvm::Value* SExtOrTrunc(Args&&... args) {
364     return mixin_builder()->CreateSExtOrTrunc(std::forward<Args>(args)...);
365   }
366 
367   template <class... Args>
Shl(Args &&...args)368   llvm::Value* Shl(Args&&... args) {
369     return mixin_builder()->CreateShl(std::forward<Args>(args)...);
370   }
371 
372   template <class... Args>
SIToFP(Args &&...args)373   llvm::Value* SIToFP(Args&&... args) {
374     return mixin_builder()->CreateSIToFP(std::forward<Args>(args)...);
375   }
376 
377   template <class... Args>
Sub(Args &&...args)378   llvm::Value* Sub(Args&&... args) {
379     return mixin_builder()->CreateSub(std::forward<Args>(args)...);
380   }
381 
382   template <class... Args>
Trunc(Args &&...args)383   llvm::Value* Trunc(Args&&... args) {
384     return mixin_builder()->CreateTrunc(std::forward<Args>(args)...);
385   }
386 
387   template <class... Args>
UIToFP(Args &&...args)388   llvm::Value* UIToFP(Args&&... args) {
389     return mixin_builder()->CreateUIToFP(std::forward<Args>(args)...);
390   }
391 
392   template <class... Args>
Unreachable(Args &&...args)393   llvm::Value* Unreachable(Args&&... args) {
394     return mixin_builder()->CreateUnreachable(std::forward<Args>(args)...);
395   }
396 
397   template <class... Args>
Xor(Args &&...args)398   llvm::Value* Xor(Args&&... args) {
399     return mixin_builder()->CreateXor(std::forward<Args>(args)...);
400   }
401 
402  private:
mixin_builder()403   llvm::IRBuilder<>* mixin_builder() {
404     return static_cast<Derived*>(this)->builder();
405   }
406 };
407 
408 }  // namespace xla
409 
410 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
411