1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
18 
#include <tuple>
#include <unordered_map>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
33 
34 namespace xla {
35 
36 class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
37  public:
38   using HloToElementGeneratorMap =
39       std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
40 
ElementalIrEmitter(llvm::Module * module,llvm::IRBuilder<> * b)41   ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b)
42       : b_(b), module_(module) {}
43 
44   virtual ~ElementalIrEmitter() = default;
45 
46   // Returns a function to generate an element of the output of `hlo`, given a
47   // map of functions to generate elements of its operands.
48   llvm_ir::ElementGenerator MakeElementGenerator(
49       const HloInstruction* hlo,
50       const HloToElementGeneratorMap& operand_to_generator);
51 
b()52   llvm::IRBuilder<>* b() { return b_; }
53 
54   // builder() is for IrBuilderMixin.
builder()55   llvm::IRBuilder<>* builder() { return b_; }
56 
module()57   llvm::Module* module() { return module_; }
58 
59  protected:
60   virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op,
61                                                    llvm::Value* lhs_value,
62                                                    llvm::Value* rhs_value);
63 
64   virtual llvm::Value* EmitExtractReal(llvm::Value* value);
65   virtual llvm::Value* EmitExtractImag(llvm::Value* value);
66 
67  private:
68   virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
69                                              llvm::Value* operand_value);
70 
71   virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
72                                               llvm::Value* lhs_value,
73                                               llvm::Value* rhs_value);
74 
75   virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op,
76                                                     llvm::Value* operand_value);
77 
78   virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(const HloInstruction* op,
79                                                   llvm::Value* operand_value);
80 
81   virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(const HloInstruction* op,
82                                                     llvm::Value* operand_value);
83 
84   llvm::Value* IsZero(llvm::Value* v);
85   llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs);
86   llvm::Value* GetZero(llvm::Type* type);
87   llvm::Value* GetOne(llvm::Type* type);
88   llvm::Value* GetIntSMin(llvm::Type* type);
89   llvm::Value* GetMinusOne(llvm::Type* type);
90 
91   llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs,
92                                  bool is_signed);
93   llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs,
94                                     bool is_signed);
95   llvm::Value* EmitIntegerPow(llvm::Value* lhs, llvm::Value* rhs,
96                               bool is_signed);
97 
98   virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
99                                                      llvm::Value* lhs_value,
100                                                      llvm::Value* rhs_value,
101                                                      bool is_signed);
102 
103   virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op,
104                                                      llvm::Value* lhs_value,
105                                                      llvm::Value* rhs_value);
106 
107   virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
108                                     llvm::Value* rhs_value,
109                                     absl::string_view name);
110 
111   virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
112                                     llvm::Value* rhs_value,
113                                     absl::string_view name);
114 
115   llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
116                                bool is_signed);
117 
118   llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
119                                bool is_signed);
120 
121   virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
122                                            llvm::Value* lhs, llvm::Value* rhs,
123                                            absl::string_view name);
124 
125   virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
126                                          llvm::Value* value);
127 
128   virtual StatusOr<llvm::Value*> EmitSqrt(PrimitiveType prim_type,
129                                           llvm::Value* value);
130 
131   virtual StatusOr<llvm::Value*> EmitCbrt(PrimitiveType prim_type,
132                                           llvm::Value* value);
133 
134   virtual StatusOr<llvm::Value*> EmitRsqrt(PrimitiveType prim_type,
135                                            llvm::Value* value);
136 
137   virtual StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type,
138                                            llvm::Value* value);
139 
140   virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type,
141                                          llvm::Value* value);
142 
143   virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type,
144                                          llvm::Value* value);
145 
146   virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
147                                          llvm::Value* value,
148                                          absl::string_view name);
149 
150   virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
151                                            llvm::Value* value);
152 
153   virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
154                                          llvm::Value* lhs, llvm::Value* rhs,
155                                          absl::string_view name);
156 
157   virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
158                                           llvm::Value* value);
159 
160   virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
161                                                      llvm::Value* x);
162 
163   virtual StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>>
164   EmitComplexAbsHelper(PrimitiveType prim_type, llvm::Value* operand_value,
165                        bool return_sqrt);
166 
167   virtual StatusOr<llvm::Value*> EmitComplexAbs(PrimitiveType prim_type,
168                                                 llvm::Value* operand_value);
169 
170   virtual StatusOr<llvm::Value*> EmitSqrtComplexAbs(PrimitiveType prim_type,
171                                                     llvm::Value* operand_value);
172   virtual StatusOr<llvm::Value*> EmitRsqrtComplexAbs(
173       PrimitiveType prim_type, llvm::Value* operand_value);
174 
175   virtual StatusOr<llvm::Value*> EmitComplexSqrt(const HloInstruction* op,
176                                                  PrimitiveType prim_type,
177                                                  llvm::Value* operand_value);
178 
179   virtual StatusOr<llvm::Value*> EmitComplexCbrt(const HloInstruction* op,
180                                                  PrimitiveType prim_type,
181                                                  llvm::Value* operand_value);
182 
183   virtual StatusOr<llvm::Value*> EmitComplexRsqrt(const HloInstruction* op,
184                                                   PrimitiveType prim_type,
185                                                   llvm::Value* operand_value);
186 
187   StatusOr<llvm::Value*> EmitAccumResult(
188       absl::Span<llvm::Value* const> accumulator_addrs,
189       llvm::ArrayRef<llvm::Type*> accumulator_types, bool is_variadic);
190 
191   // Composes a complex struct. imag may be nullptr for simple cast operations.
192   llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real,
193                                   llvm::Value* imag);
194 
195   // Emit `accumulator + lhs * rhs` for the given primitive type.
196   llvm::Value* EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
197                           llvm::Value* accumulator,
198                           xla::PrimitiveType primitive_type);
199 
200   // Identifier of the thread unique among all threads on the device
EmitThreadId()201   virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); }
202 
203   StatusOr<llvm::Value*> EmitElementalSelect(
204       const HloInstruction* hlo,
205       const HloToElementGeneratorMap& operand_to_generator,
206       const llvm_ir::IrArray::Index& index);
207 
208   StatusOr<llvm::Value*> EmitElementalClamp(
209       const HloInstruction* hlo,
210       const HloToElementGeneratorMap& operand_to_generator,
211       const llvm_ir::IrArray::Index& index);
212 
213   StatusOr<llvm::Value*> EmitElementalConcatenate(
214       const HloInstruction* hlo,
215       const HloToElementGeneratorMap& operand_to_generator,
216       const llvm_ir::IrArray::Index& target_index);
217 
218   StatusOr<llvm::Value*> EmitElementalDynamicSlice(
219       const HloInstruction* hlo,
220       const HloToElementGeneratorMap& operand_to_generator,
221       const llvm_ir::IrArray::Index& index);
222 
223   StatusOr<llvm::Value*> EmitElementalGather(
224       const HloInstruction* hlo,
225       const HloToElementGeneratorMap& operand_to_generator,
226       const llvm_ir::IrArray::Index& index);
227 
228   StatusOr<llvm::Value*> EmitElementalDynamicUpdateSlice(
229       const HloInstruction* hlo,
230       const HloToElementGeneratorMap& operand_to_generator,
231       const llvm_ir::IrArray::Index& index);
232 
233   StatusOr<llvm::Value*> EmitElementalPad(
234       const HloInstruction* hlo,
235       const HloToElementGeneratorMap& operand_to_generator,
236       const llvm_ir::IrArray::Index& padded_index);
237 
238   StatusOr<llvm::Value*> EmitElementalDot(
239       const HloInstruction* hlo,
240       const HloToElementGeneratorMap& operand_to_generator,
241       const llvm_ir::IrArray::Index& dot_result_index);
242 
243   virtual StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
244       const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
245       absl::string_view name) = 0;
246 
247   StatusOr<llvm::Value*> EmitElementalMap(
248       const HloMapInstruction* map_instr,
249       absl::Span<llvm::Value* const> elemental_operands);
250 
251   StatusOr<llvm::Value*> EmitElementalReduceWindow(
252       const HloReduceWindowInstruction* reduce_window,
253       std::vector<llvm_ir::ElementGenerator> input_generators,
254       std::vector<llvm_ir::ElementGenerator> initial_value_generators,
255       const llvm_ir::IrArray::Index& index);
256 
257   StatusOr<llvm::Value*> EmitElementalReduce(
258       const HloReduceInstruction* reduce,
259       std::vector<llvm_ir::ElementGenerator> input_generators,
260       std::vector<llvm_ir::ElementGenerator> initial_value_generators,
261       const llvm_ir::IrArray::Index& index);
262 
263   virtual StatusOr<llvm::Value*> EmitConvolution(
264       const HloInstruction* hlo,
265       const HloToElementGeneratorMap& operand_to_generator,
266       const llvm_ir::IrArray::Index& index);
267 
268   // Computes the complex power function, returns (a + i*b)^(c + i*d).
269   StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op,
270                                           llvm::Value* a, llvm::Value* b,
271                                           llvm::Value* c, llvm::Value* d);
272 
273   // Evaluates a polynomial using Horner's method.
274   StatusOr<llvm::Value*> EvaluatePolynomial(
275       llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients);
276 
277   virtual bool fast_min_max() = 0;
278 
279   llvm::IRBuilder<>* const b_;
280 
281   llvm::Module* module_;
282 };
283 
284 }  // namespace xla
285 
286 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
287