1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
18 
19 #include "llvm/IR/IRBuilder.h"
20 
21 namespace xla {
22 
23 // Mixin class that injects more ergonomic versions of llvm::IRBuilder methods
24 // into a class.  Intended to be used as a CRTP base class, like:
25 //
26 //  class MyIrEmitter : public IrBuilderMixin<MyIrEmitter> {
27 //    llvm::IRBuilder<>* builder() { return builder_; }
28 //
29 //    void EmitFoo(HloInstruction* foo) {
30 //      Add(Mul(...), FPToUI(...));
31 //    }
32 //  };
33 
34 template <typename Derived>
35 class IrBuilderMixin {
36  protected:
37   template <class... Args>
Add(Args &&...args)38   llvm::Value* Add(Args&&... args) {
39     return mixin_builder()->CreateAdd(std::forward<Args>(args)...);
40   }
41 
42   template <class... Args>
AlignedLoad(Args &&...args)43   llvm::LoadInst* AlignedLoad(Args&&... args) {
44     return mixin_builder()->CreateAlignedLoad(std::forward<Args>(args)...);
45   }
46 
47   template <class... Args>
AlignedStore(Args &&...args)48   llvm::StoreInst* AlignedStore(Args&&... args) {
49     return mixin_builder()->CreateAlignedStore(std::forward<Args>(args)...);
50   }
51 
52   template <class... Args>
Alloca(Args &&...args)53   llvm::AllocaInst* Alloca(Args&&... args) {
54     return mixin_builder()->CreateAlloca(std::forward<Args>(args)...);
55   }
56 
57   template <class... Args>
And(Args &&...args)58   llvm::Value* And(Args&&... args) {
59     return mixin_builder()->CreateAnd(std::forward<Args>(args)...);
60   }
61 
62   template <class... Args>
AtomicCmpXchg(Args &&...args)63   llvm::Value* AtomicCmpXchg(Args&&... args) {
64     return mixin_builder()->CreateAtomicCmpXchg(std::forward<Args>(args)...);
65   }
66 
67   template <class... Args>
AtomicRMW(Args &&...args)68   llvm::Value* AtomicRMW(Args&&... args) {
69     return mixin_builder()->CreateAtomicRMW(std::forward<Args>(args)...);
70   }
71 
72   template <class... Args>
BitCast(Args &&...args)73   llvm::Value* BitCast(Args&&... args) {
74     return mixin_builder()->CreateBitCast(std::forward<Args>(args)...);
75   }
76 
77   template <class... Args>
Br(Args &&...args)78   llvm::Value* Br(Args&&... args) {
79     return mixin_builder()->CreateBr(std::forward<Args>(args)...);
80   }
81 
82   llvm::CallInst* Call(llvm::FunctionCallee func_callee,
83                        llvm::ArrayRef<llvm::Value*> args = llvm::None,
84                        const llvm::Twine& name = "",
85                        llvm::MDNode* fp_math_tag = nullptr) {
86     return mixin_builder()->CreateCall(func_callee, args, name, fp_math_tag);
87   }
88 
89   llvm::CallInst* Call(llvm::FunctionType* func_type, llvm::Value* callee,
90                        llvm::ArrayRef<llvm::Value*> args = llvm::None,
91                        const llvm::Twine& name = "",
92                        llvm::MDNode* fp_math_tag = nullptr) {
93     return mixin_builder()->CreateCall(func_type, callee, args, name,
94                                        fp_math_tag);
95   }
96 
97   template <class... Args>
CondBr(Args &&...args)98   llvm::BranchInst* CondBr(Args&&... args) {
99     return mixin_builder()->CreateCondBr(std::forward<Args>(args)...);
100   }
101 
102   template <class... Args>
ConstInBoundsGEP1_32(Args &&...args)103   llvm::Value* ConstInBoundsGEP1_32(Args&&... args) {
104     return mixin_builder()->CreateConstInBoundsGEP1_32(
105         std::forward<Args>(args)...);
106   }
107 
108   template <class... Args>
FAdd(Args &&...args)109   llvm::Value* FAdd(Args&&... args) {
110     return mixin_builder()->CreateFAdd(std::forward<Args>(args)...);
111   }
112 
113   template <class... Args>
FMul(Args &&...args)114   llvm::Value* FMul(Args&&... args) {
115     return mixin_builder()->CreateFMul(std::forward<Args>(args)...);
116   }
117 
118   llvm::Value* GEP(llvm::Value* ptr, llvm::ArrayRef<llvm::Value*> idx_list,
119                    const llvm::Twine& name = "") {
120     return mixin_builder()->CreateGEP(ptr, idx_list, name);
121   }
122 
123   template <class... Args>
ICmpEQ(Args &&...args)124   llvm::Value* ICmpEQ(Args&&... args) {
125     return mixin_builder()->CreateICmpEQ(std::forward<Args>(args)...);
126   }
127 
128   template <class... Args>
ICmpNE(Args &&...args)129   llvm::Value* ICmpNE(Args&&... args) {
130     return mixin_builder()->CreateICmpNE(std::forward<Args>(args)...);
131   }
132 
133   template <class... Args>
ICmpULE(Args &&...args)134   llvm::Value* ICmpULE(Args&&... args) {
135     return mixin_builder()->CreateICmpULE(std::forward<Args>(args)...);
136   }
137 
138   template <class... Args>
ICmpULT(Args &&...args)139   llvm::Value* ICmpULT(Args&&... args) {
140     return mixin_builder()->CreateICmpULT(std::forward<Args>(args)...);
141   }
142 
143   llvm::Value* InBoundsGEP(llvm::Value* ptr,
144                            llvm::ArrayRef<llvm::Value*> idx_list,
145                            const llvm::Twine& name = "") {
146     return mixin_builder()->CreateInBoundsGEP(ptr, idx_list, name);
147   }
148 
149   llvm::Value* ExtractValue(llvm::Value* agg, llvm::ArrayRef<unsigned> idxs,
150                             const llvm::Twine& name = "") {
151     return mixin_builder()->CreateExtractValue(agg, idxs, name);
152   }
153 
154   llvm::Value* InsertValue(llvm::Value* agg, llvm::Value* val,
155                            llvm::ArrayRef<unsigned> idxs,
156                            const llvm::Twine& name = "") {
157     return mixin_builder()->CreateInsertValue(agg, val, idxs, name);
158   }
159 
160   template <class... Args>
IntToPtr(Args &&...args)161   llvm::Value* IntToPtr(Args&&... args) {
162     return mixin_builder()->CreateIntToPtr(std::forward<Args>(args)...);
163   }
164 
165   template <class... Args>
Load(Args &&...args)166   llvm::LoadInst* Load(Args&&... args) {
167     return mixin_builder()->CreateLoad(std::forward<Args>(args)...);
168   }
169 
170   template <class... Args>
MemCpy(Args &&...args)171   llvm::CallInst* MemCpy(Args&&... args) {
172     return mixin_builder()->CreateMemCpy(std::forward<Args>(args)...);
173   }
174 
175   template <class... Args>
Mul(Args &&...args)176   llvm::Value* Mul(Args&&... args) {
177     return mixin_builder()->CreateMul(std::forward<Args>(args)...);
178   }
179 
180   template <class... Args>
NSWAdd(Args &&...args)181   llvm::Value* NSWAdd(Args&&... args) {
182     return mixin_builder()->CreateNSWAdd(std::forward<Args>(args)...);
183   }
184 
185   template <class... Args>
NSWMul(Args &&...args)186   llvm::Value* NSWMul(Args&&... args) {
187     return mixin_builder()->CreateNSWMul(std::forward<Args>(args)...);
188   }
189 
190   template <class... Args>
NSWSub(Args &&...args)191   llvm::Value* NSWSub(Args&&... args) {
192     return mixin_builder()->CreateNSWSub(std::forward<Args>(args)...);
193   }
194 
195   template <class... Args>
Or(Args &&...args)196   llvm::Value* Or(Args&&... args) {
197     return mixin_builder()->CreateOr(std::forward<Args>(args)...);
198   }
199 
200   template <class... Args>
PointerCast(Args &&...args)201   llvm::Value* PointerCast(Args&&... args) {
202     return mixin_builder()->CreatePointerCast(std::forward<Args>(args)...);
203   }
204 
205   template <class... Args>
PtrToInt(Args &&...args)206   llvm::Value* PtrToInt(Args&&... args) {
207     return mixin_builder()->CreatePtrToInt(std::forward<Args>(args)...);
208   }
209 
210   template <class... Args>
SDiv(Args &&...args)211   llvm::Value* SDiv(Args&&... args) {
212     return mixin_builder()->CreateSDiv(std::forward<Args>(args)...);
213   }
214 
215   template <class... Args>
Select(Args &&...args)216   llvm::Value* Select(Args&&... args) {
217     return mixin_builder()->CreateSelect(std::forward<Args>(args)...);
218   }
219 
220   template <class... Args>
SRem(Args &&...args)221   llvm::Value* SRem(Args&&... args) {
222     return mixin_builder()->CreateSRem(std::forward<Args>(args)...);
223   }
224 
225   template <class... Args>
Store(Args &&...args)226   llvm::StoreInst* Store(Args&&... args) {
227     return mixin_builder()->CreateStore(std::forward<Args>(args)...);
228   }
229 
230   template <class... Args>
UDiv(Args &&...args)231   llvm::Value* UDiv(Args&&... args) {
232     return mixin_builder()->CreateUDiv(std::forward<Args>(args)...);
233   }
234 
235   template <class... Args>
URem(Args &&...args)236   llvm::Value* URem(Args&&... args) {
237     return mixin_builder()->CreateURem(std::forward<Args>(args)...);
238   }
239 
240   template <class... Args>
VectorSplat(Args &&...args)241   llvm::Value* VectorSplat(Args&&... args) {
242     return mixin_builder()->CreateVectorSplat(std::forward<Args>(args)...);
243   }
244 
245   template <class... Args>
ZExtOrTrunc(Args &&...args)246   llvm::Value* ZExtOrTrunc(Args&&... args) {
247     return mixin_builder()->CreateZExtOrTrunc(std::forward<Args>(args)...);
248   }
249 
250   template <class... Args>
AShr(Args &&...args)251   llvm::Value* AShr(Args&&... args) {
252     return mixin_builder()->CreateAShr(std::forward<Args>(args)...);
253   }
254 
255   template <class... Args>
FCmpOEQ(Args &&...args)256   llvm::Value* FCmpOEQ(Args&&... args) {
257     return mixin_builder()->CreateFCmpOEQ(std::forward<Args>(args)...);
258   }
259 
260   template <class... Args>
FCmpOGT(Args &&...args)261   llvm::Value* FCmpOGT(Args&&... args) {
262     return mixin_builder()->CreateFCmpOGT(std::forward<Args>(args)...);
263   }
264 
265   template <class... Args>
FCmpOGE(Args &&...args)266   llvm::Value* FCmpOGE(Args&&... args) {
267     return mixin_builder()->CreateFCmpOGE(std::forward<Args>(args)...);
268   }
269 
270   template <class... Args>
FCmpOLT(Args &&...args)271   llvm::Value* FCmpOLT(Args&&... args) {
272     return mixin_builder()->CreateFCmpOLT(std::forward<Args>(args)...);
273   }
274 
275   template <class... Args>
FCmpULT(Args &&...args)276   llvm::Value* FCmpULT(Args&&... args) {
277     return mixin_builder()->CreateFCmpULT(std::forward<Args>(args)...);
278   }
279 
280   template <class... Args>
FCmpOLE(Args &&...args)281   llvm::Value* FCmpOLE(Args&&... args) {
282     return mixin_builder()->CreateFCmpOLE(std::forward<Args>(args)...);
283   }
284 
285   template <class... Args>
FCmpONE(Args &&...args)286   llvm::Value* FCmpONE(Args&&... args) {
287     return mixin_builder()->CreateFCmpONE(std::forward<Args>(args)...);
288   }
289 
290   template <class... Args>
FCmpUNE(Args &&...args)291   llvm::Value* FCmpUNE(Args&&... args) {
292     return mixin_builder()->CreateFCmpUNE(std::forward<Args>(args)...);
293   }
294 
295   template <class... Args>
FCmpUNO(Args &&...args)296   llvm::Value* FCmpUNO(Args&&... args) {
297     return mixin_builder()->CreateFCmpUNO(std::forward<Args>(args)...);
298   }
299 
300   template <class... Args>
FDiv(Args &&...args)301   llvm::Value* FDiv(Args&&... args) {
302     return mixin_builder()->CreateFDiv(std::forward<Args>(args)...);
303   }
304 
305   template <class... Args>
FNeg(Args &&...args)306   llvm::Value* FNeg(Args&&... args) {
307     return mixin_builder()->CreateFNeg(std::forward<Args>(args)...);
308   }
309 
310   template <class... Args>
FPCast(Args &&...args)311   llvm::Value* FPCast(Args&&... args) {
312     return mixin_builder()->CreateFPCast(std::forward<Args>(args)...);
313   }
314 
315   template <class... Args>
FPToSI(Args &&...args)316   llvm::Value* FPToSI(Args&&... args) {
317     return mixin_builder()->CreateFPToSI(std::forward<Args>(args)...);
318   }
319 
320   template <class... Args>
FPToUI(Args &&...args)321   llvm::Value* FPToUI(Args&&... args) {
322     return mixin_builder()->CreateFPToUI(std::forward<Args>(args)...);
323   }
324 
325   template <class... Args>
FPTrunc(Args &&...args)326   llvm::Value* FPTrunc(Args&&... args) {
327     return mixin_builder()->CreateFPTrunc(std::forward<Args>(args)...);
328   }
329 
330   template <class... Args>
FRem(Args &&...args)331   llvm::Value* FRem(Args&&... args) {
332     return mixin_builder()->CreateFRem(std::forward<Args>(args)...);
333   }
334 
335   template <class... Args>
FSub(Args &&...args)336   llvm::Value* FSub(Args&&... args) {
337     return mixin_builder()->CreateFSub(std::forward<Args>(args)...);
338   }
339 
340   template <class... Args>
ICmpSGE(Args &&...args)341   llvm::Value* ICmpSGE(Args&&... args) {
342     return mixin_builder()->CreateICmpSGE(std::forward<Args>(args)...);
343   }
344 
345   template <class... Args>
ICmpSLT(Args &&...args)346   llvm::Value* ICmpSLT(Args&&... args) {
347     return mixin_builder()->CreateICmpSLT(std::forward<Args>(args)...);
348   }
349 
350   template <class... Args>
IntCast(Args &&...args)351   llvm::Value* IntCast(Args&&... args) {
352     return mixin_builder()->CreateIntCast(std::forward<Args>(args)...);
353   }
354 
355   template <class... Args>
LShr(Args &&...args)356   llvm::Value* LShr(Args&&... args) {
357     return mixin_builder()->CreateLShr(std::forward<Args>(args)...);
358   }
359 
360   template <class... Args>
MemSet(Args &&...args)361   llvm::Value* MemSet(Args&&... args) {
362     return mixin_builder()->CreateMemSet(std::forward<Args>(args)...);
363   }
364 
365   template <class... Args>
Neg(Args &&...args)366   llvm::Value* Neg(Args&&... args) {
367     return mixin_builder()->CreateNeg(std::forward<Args>(args)...);
368   }
369 
370   template <class... Args>
Not(Args &&...args)371   llvm::Value* Not(Args&&... args) {
372     return mixin_builder()->CreateNot(std::forward<Args>(args)...);
373   }
374 
375   template <class... Args>
PHI(Args &&...args)376   llvm::PHINode* PHI(Args&&... args) {
377     return mixin_builder()->CreatePHI(std::forward<Args>(args)...);
378   }
379 
380   template <class... Args>
RetVoid(Args &&...args)381   llvm::Value* RetVoid(Args&&... args) {
382     return mixin_builder()->CreateRetVoid(std::forward<Args>(args)...);
383   }
384 
385   template <class... Args>
SExtOrTrunc(Args &&...args)386   llvm::Value* SExtOrTrunc(Args&&... args) {
387     return mixin_builder()->CreateSExtOrTrunc(std::forward<Args>(args)...);
388   }
389 
390   template <class... Args>
Shl(Args &&...args)391   llvm::Value* Shl(Args&&... args) {
392     return mixin_builder()->CreateShl(std::forward<Args>(args)...);
393   }
394 
395   template <class... Args>
SIToFP(Args &&...args)396   llvm::Value* SIToFP(Args&&... args) {
397     return mixin_builder()->CreateSIToFP(std::forward<Args>(args)...);
398   }
399 
400   template <class... Args>
Sub(Args &&...args)401   llvm::Value* Sub(Args&&... args) {
402     return mixin_builder()->CreateSub(std::forward<Args>(args)...);
403   }
404 
405   template <class... Args>
Trunc(Args &&...args)406   llvm::Value* Trunc(Args&&... args) {
407     return mixin_builder()->CreateTrunc(std::forward<Args>(args)...);
408   }
409 
410   template <class... Args>
UIToFP(Args &&...args)411   llvm::Value* UIToFP(Args&&... args) {
412     return mixin_builder()->CreateUIToFP(std::forward<Args>(args)...);
413   }
414 
415   template <class... Args>
Unreachable(Args &&...args)416   llvm::Value* Unreachable(Args&&... args) {
417     return mixin_builder()->CreateUnreachable(std::forward<Args>(args)...);
418   }
419 
420   template <class... Args>
Xor(Args &&...args)421   llvm::Value* Xor(Args&&... args) {
422     return mixin_builder()->CreateXor(std::forward<Args>(args)...);
423   }
424 
425  private:
mixin_builder()426   llvm::IRBuilder<>* mixin_builder() {
427     return static_cast<Derived*>(this)->builder();
428   }
429 };
430 
431 }  // namespace xla
432 
433 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_
434