1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/strings/str_cat.h"
27 #include "llvm/IR/BasicBlock.h"
28 #include "llvm/IR/Constants.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/Intrinsics.h"
31 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
32 #include "tensorflow/compiler/xla/primitive_util.h"
33 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
34 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
35 #include "tensorflow/compiler/xla/service/hlo_module.h"
36 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
37 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
38 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
39 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
40 #include "tensorflow/compiler/xla/shape_util.h"
41 #include "tensorflow/compiler/xla/status_macros.h"
42 #include "tensorflow/compiler/xla/statusor.h"
43 #include "tensorflow/compiler/xla/types.h"
44 #include "tensorflow/compiler/xla/util.h"
45 #include "tensorflow/compiler/xla/window_util.h"
46 #include "tensorflow/compiler/xla/xla_data.pb.h"
47 #include "tensorflow/core/lib/random/random.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/types.h"
50 
51 namespace xla {
52 
53 using absl::StrCat;
54 using llvm_ir::IrArray;
55 using llvm_ir::IrName;
56 using llvm_ir::SetToFirstInsertPoint;
57 
58 namespace {
59 
GlobalRandomValue()60 int64 GlobalRandomValue() {
61   static auto* mu = new tensorflow::mutex();
62   static std::mt19937_64 rng{42};
63   tensorflow::mutex_lock l(*mu);
64   return rng();
65 }
66 
EmitReducePrecisionIR(PrimitiveType src_ty,llvm::Value * x,int64 dest_exponent_bits,int64 dest_mantissa_bits,llvm::IRBuilder<> * b)67 StatusOr<llvm::Value*> EmitReducePrecisionIR(PrimitiveType src_ty,
68                                              llvm::Value* x,
69                                              int64 dest_exponent_bits,
70                                              int64 dest_mantissa_bits,
71                                              llvm::IRBuilder<>* b) {
72   using llvm::APInt;
73 
74   if (!primitive_util::IsFloatingPointType(src_ty)) {
75     return Unimplemented(
76         "ReducePrecision cannot accept non-floating-point type %s.",
77         PrimitiveType_Name(src_ty));
78   }
79 
80   // Integer and float types for casting and constant generation.
81   llvm::Type* float_type = x->getType();
82   int64 nbits = float_type->getPrimitiveSizeInBits();
83   llvm::IntegerType* int_type = b->getIntNTy(nbits);
84 
85   // SignificandWidth includes the implicit extra bit.
86   int src_mantissa_bits = primitive_util::SignificandWidth(src_ty) - 1;
87   int src_exponent_bits = nbits - 1 - src_mantissa_bits;
88 
89   // Cast the input value to an integer for bitwise manipulation.
90   llvm::Value* x_as_int = b->CreateBitCast(x, int_type);
91 
92   if (dest_mantissa_bits < src_mantissa_bits) {
93     // Last remaining mantissa bit.
94     APInt last_mantissa_bit_mask(nbits, 1);
95     last_mantissa_bit_mask <<= src_mantissa_bits - dest_mantissa_bits;
96 
97     // Compute rounding bias for round-to-nearest with ties to even.  This is
98     // equal to a base value of 0111... plus one bit if the last remaining
99     // mantissa bit is 1.
100     APInt base_rounding_bias = last_mantissa_bit_mask.lshr(1) - 1;
101     llvm::Value* x_last_mantissa_bit = b->CreateLShr(
102         b->CreateAnd(x_as_int,
103                      llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
104         (src_mantissa_bits - dest_mantissa_bits));
105     llvm::Value* x_rounding_bias =
106         b->CreateAdd(x_last_mantissa_bit,
107                      llvm::ConstantInt::get(int_type, base_rounding_bias));
108 
109     // Add rounding bias, and mask out truncated bits.  Note that the case
110     // where adding the rounding bias overflows into the exponent bits is
111     // correct; the non-masked mantissa bits will all be zero, and the
112     // exponent will be incremented by one.
113     APInt truncation_mask = ~(last_mantissa_bit_mask - 1);
114     x_as_int = b->CreateAdd(x_as_int, x_rounding_bias);
115     x_as_int = b->CreateAnd(x_as_int,
116                             llvm::ConstantInt::get(int_type, truncation_mask));
117   }
118 
119   if (dest_exponent_bits < src_exponent_bits) {
120     APInt sign_bit_mask(nbits, 1);
121     sign_bit_mask <<= nbits - 1;
122 
123     APInt exp_bits_mask(nbits, 1);
124     exp_bits_mask = ((exp_bits_mask << src_exponent_bits) - 1)
125                     << src_mantissa_bits;
126 
127     // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
128     // significant bit -- is equal to 1.0f for all exponent sizes.  Adding
129     // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
130     // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest'
131     // exponent (corresponding to 0.0f).
132     //
133     // Thus, the f32 exponent corresponding to the highest non-infinite
134     // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
135     // exponent corresponding to the lowest exponent for a bit size of n is
136     // (2^7-1) - 2^(n-1)-1.
137     //
138     // Note that we have already checked that exponents_bits >= 1.
139     APInt exponent_bias(nbits, 1);
140     exponent_bias = (exponent_bias << (src_exponent_bits - 1)) - 1;
141 
142     APInt reduced_exponent_bias(nbits, 1);
143     reduced_exponent_bias =
144         (reduced_exponent_bias << (dest_exponent_bits - 1)) - 1;
145 
146     APInt reduced_max_exponent = exponent_bias + reduced_exponent_bias;
147     APInt reduced_min_exponent = exponent_bias - reduced_exponent_bias;
148 
149     // Do we overflow or underflow?
150     llvm::Value* x_exponent =
151         b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, exp_bits_mask));
152     llvm::Value* x_overflows = b->CreateICmpUGT(
153         x_exponent, llvm::ConstantInt::get(
154                         int_type, reduced_max_exponent << src_mantissa_bits));
155     llvm::Value* x_underflows = b->CreateICmpULE(
156         x_exponent, llvm::ConstantInt::get(
157                         int_type, reduced_min_exponent << src_mantissa_bits));
158 
159     // Compute appropriately-signed values of zero and infinity.
160     llvm::Value* x_signed_zero =
161         b->CreateAnd(x_as_int, llvm::ConstantInt::get(int_type, sign_bit_mask));
162     llvm::Value* x_signed_inf = b->CreateOr(
163         x_signed_zero, llvm::ConstantInt::get(int_type, exp_bits_mask));
164 
165     // Force to zero or infinity if overflow or underflow.  (Note that this
166     // truncates all denormal values to zero, rather than rounding them.)
167     x_as_int = b->CreateSelect(x_overflows, x_signed_inf, x_as_int);
168     x_as_int = b->CreateSelect(x_underflows, x_signed_zero, x_as_int);
169   }
170 
171   // Cast the result back to a floating-point type.
172   llvm::Value* result = b->CreateBitCast(x_as_int, float_type);
173 
174   // Correct result for NaN inputs.
175   //
176   // The exponent handling will "normalize" NaN values to infinities, which is
177   // undesirable (except in the case with no mantissa bits, in which case it
178   // is mandatory).  This logic also handles cases where mantissa-rounding
179   // causes a NaN's mantissa to overflow into the exponent bits, which would
180   // otherwise create an erroneous zero value.
181   //
182   // If the fast-math flags are set to assume no NaNs, the comparison is likely
183   // to be optimized away, so there's no point in even emitting it.
184   if (!b->getFastMathFlags().noNaNs()) {
185     llvm::Value* x_is_nan = b->CreateFCmpUNO(x, x);
186 
187     if (dest_mantissa_bits > 0) {
188       result = b->CreateSelect(x_is_nan, x, result);
189     } else {
190       result = b->CreateSelect(
191           x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
192     }
193   }
194   return result;
195 }
196 
EmitF32ToBF16(llvm::Value * f32_value,llvm::IRBuilder<> * b)197 StatusOr<llvm::Value*> EmitF32ToBF16(llvm::Value* f32_value,
198                                      llvm::IRBuilder<>* b) {
199   TF_ASSIGN_OR_RETURN(
200       auto reduced_precision,
201       EmitReducePrecisionIR(
202           /*src_ty=*/F32, f32_value,
203           /*dest_exponent_bits=*/primitive_util::ExponentWidth(BF16),
204           /*dest_mantissa_bits=*/primitive_util::SignificandWidth(BF16) - 1,
205           b));
206   auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
207   auto shifted = b->CreateLShr(as_int32, 16);
208   auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
209   return b->CreateBitCast(truncated, b->getInt16Ty());
210 }
211 
EmitBF16ToF32(llvm::Value * bf16_value,llvm::IRBuilder<> * b)212 llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, llvm::IRBuilder<>* b) {
213   auto as_int16 = b->CreateBitCast(bf16_value, b->getInt16Ty());
214   auto as_int32 = b->CreateZExt(as_int16, b->getInt32Ty());
215   auto shifted = b->CreateShl(as_int32, 16);
216   return b->CreateBitCast(shifted, b->getFloatTy());
217 }
218 
EmitIntegralToFloating(llvm::Value * integer_value,PrimitiveType from_type,PrimitiveType to_type,llvm::Module * module,llvm::IRBuilder<> * b)219 llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
220                                     PrimitiveType from_type,
221                                     PrimitiveType to_type, llvm::Module* module,
222                                     llvm::IRBuilder<>* b) {
223   if (primitive_util::IsSignedIntegralType(from_type)) {
224     return b->CreateSIToFP(integer_value,
225                            llvm_ir::PrimitiveTypeToIrType(to_type, module));
226   } else {
227     CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
228           from_type == PRED);
229     return b->CreateUIToFP(integer_value,
230                            llvm_ir::PrimitiveTypeToIrType(to_type, module));
231   }
232 }
233 
234 }  // namespace
235 
EmitUnaryOp(const HloInstruction * op,llvm::Value * operand_value)236 StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
237     const HloInstruction* op, llvm::Value* operand_value) {
238   if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
239       op->operand(0)->shape().element_type() == PRED) {
240     return EmitIntegerUnaryOp(op, operand_value);
241   } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
242     return EmitComplexUnaryOp(op, operand_value);
243   } else {
244     return EmitFloatUnaryOp(op, operand_value);
245   }
246 }
247 
EmitIntegerUnaryOp(const HloInstruction * op,llvm::Value * operand_value)248 StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
249     const HloInstruction* op, llvm::Value* operand_value) {
250   switch (op->opcode()) {
251     case HloOpcode::kConvert: {
252       PrimitiveType from_type = op->operand(0)->shape().element_type();
253       PrimitiveType to_type = op->shape().element_type();
254       CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED)
255           << from_type;
256       if (from_type == to_type) {
257         return operand_value;
258       }
259       if (to_type == PRED) {
260         return b_->CreateZExt(
261             ICmpNE(operand_value,
262                    llvm::ConstantInt::get(operand_value->getType(), 0)),
263             llvm_ir::PrimitiveTypeToIrType(PRED, module_));
264       }
265       if (primitive_util::IsIntegralType(to_type)) {
266         return IntCast(operand_value,
267                        llvm_ir::PrimitiveTypeToIrType(to_type, module_),
268                        primitive_util::IsSignedIntegralType(from_type));
269       }
270       if (primitive_util::IsFloatingPointType(to_type)) {
271         if (to_type == BF16) {
272           return EmitF32ToBF16(EmitIntegralToFloating(operand_value, from_type,
273                                                       F32, module_, b_),
274                                b_);
275         }
276         return EmitIntegralToFloating(operand_value, from_type, to_type,
277                                       module_, b_);
278       }
279       if (primitive_util::IsComplexType(to_type)) {
280         auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
281             primitive_util::ComplexComponentType(to_type), module_);
282         if (primitive_util::IsSignedIntegralType(from_type)) {
283           return EmitComposeComplex(
284               op, SIToFP(operand_value, to_ir_component_type), nullptr);
285         }
286         if (primitive_util::IsUnsignedIntegralType(from_type) ||
287             from_type == PRED) {
288           return EmitComposeComplex(
289               op, UIToFP(operand_value, to_ir_component_type), nullptr);
290         }
291       }
292       return Unimplemented("conversion from primitive type %s to %s",
293                            PrimitiveType_Name(from_type),
294                            PrimitiveType_Name(to_type));
295     }
296     case HloOpcode::kBitcastConvert: {
297       PrimitiveType from_type = op->operand(0)->shape().element_type();
298       PrimitiveType to_type = op->shape().element_type();
299       CHECK(primitive_util::IsIntegralType(from_type));
300       if (from_type == to_type) {
301         return operand_value;
302       }
303       if (primitive_util::BitWidth(from_type) ==
304           primitive_util::BitWidth(to_type)) {
305         return BitCast(operand_value,
306                        llvm_ir::PrimitiveTypeToIrType(to_type, module_));
307       }
308       return InvalidArgument(
309           "bitcast conversion from primitive type %s to %s with unequal "
310           "bit-widths (%u versus %u) ",
311           PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
312           primitive_util::BitWidth(from_type),
313           primitive_util::BitWidth(to_type));
314     }
315     case HloOpcode::kAbs: {
316       bool is_signed =
317           primitive_util::IsSignedIntegralType(op->shape().element_type());
318       if (is_signed) {
319         auto type =
320             llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
321         auto cmp = ICmpSGE(operand_value, GetZero(type));
322         return Select(cmp, operand_value, Neg(operand_value));
323       } else {
324         return operand_value;
325       }
326     }
327     case HloOpcode::kClz: {
328       auto is_zero_undef = b_->getFalse();
329       return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctlz,
330                                           {operand_value, is_zero_undef},
331                                           {operand_value->getType()}, b_);
332     }
333     case HloOpcode::kSign: {
334       CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type()))
335           << op->shape().element_type();
336       auto type =
337           llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
338       auto cmp = ICmpEQ(operand_value, GetZero(type));
339       auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1);
340       return Select(cmp, GetZero(type), Or(ashr, 1));
341     }
342     case HloOpcode::kNegate:
343       return Neg(operand_value);
344     case HloOpcode::kNot: {
345       auto type = op->shape().element_type();
346       if (type == PRED) {
347         // It is not sufficient to just call CreateNot() here because a PRED
348         // is represented as an i8 and the truth value is stored only in the
349         // bottom bit.
350         return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())),
351                               llvm_ir::PrimitiveTypeToIrType(PRED, module_));
352       } else if (primitive_util::IsIntegralType(type)) {
353         return Not(operand_value);
354       }
355       return Unimplemented("unary op Not is not defined for type '%d'", type);
356     }
357     case HloOpcode::kPopulationCount: {
358       return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctpop,
359                                           {operand_value},
360                                           {operand_value->getType()}, b_);
361     }
362     default:
363       return Unimplemented("unary integer op '%s'",
364                            HloOpcodeString(op->opcode()));
365   }
366 }
367 
EmitFloatUnaryOp(const HloInstruction * op,llvm::Value * operand_value)368 StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
369     const HloInstruction* op, llvm::Value* operand_value) {
370   switch (op->opcode()) {
371     case HloOpcode::kConvert: {
372       PrimitiveType from_type = op->operand(0)->shape().element_type();
373       PrimitiveType to_type = op->shape().element_type();
374       CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type;
375       if (from_type == to_type) {
376         return operand_value;
377       }
378       if (from_type == BF16) {
379         TF_RET_CHECK(to_type != BF16);
380         operand_value = EmitBF16ToF32(operand_value, b_);
381         from_type = F32;
382         if (from_type == to_type) {
383           return operand_value;
384         }
385       }
386       if (primitive_util::IsComplexType(to_type)) {
387         PrimitiveType to_component_type =
388             primitive_util::ComplexComponentType(to_type);
389         if (from_type == to_component_type) {
390           return EmitComposeComplex(op, operand_value, nullptr);
391         }
392         return EmitComposeComplex(
393             op,
394             FPCast(operand_value,
395                    llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
396             nullptr);
397       }
398       if (to_type == BF16) {
399         // Cast to F32 first. Other floating point formats are not supported by
400         // EmitReducePrecisionIR.
401         if (from_type != F32) {
402           operand_value = b_->CreateFPCast(
403               operand_value, llvm_ir::PrimitiveTypeToIrType(F32, module_));
404         }
405         return EmitF32ToBF16(operand_value, b_);
406       }
407       if (to_type == PRED) {
408         return b_->CreateZExt(
409             FCmpUNE(operand_value,
410                     llvm::ConstantFP::get(operand_value->getType(), 0.0)),
411             llvm_ir::PrimitiveTypeToIrType(PRED, module_));
412       }
413       if (primitive_util::IsFloatingPointType(to_type)) {
414         return FPCast(operand_value,
415                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
416       }
417       if (primitive_util::IsSignedIntegralType(to_type)) {
418         return FPToSI(operand_value,
419                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
420       }
421       if (primitive_util::IsUnsignedIntegralType(to_type)) {
422         return FPToUI(operand_value,
423                       llvm_ir::PrimitiveTypeToIrType(to_type, module_));
424       }
425       return Unimplemented("unhandled conversion operation: %s => %s",
426                            PrimitiveType_Name(from_type),
427                            PrimitiveType_Name(to_type));
428     }
429     case HloOpcode::kBitcastConvert: {
430       PrimitiveType from_type = op->operand(0)->shape().element_type();
431       PrimitiveType to_type = op->shape().element_type();
432       CHECK(primitive_util::IsFloatingPointType(from_type));
433       if (from_type == to_type) {
434         return operand_value;
435       }
436       if (primitive_util::BitWidth(from_type) ==
437           primitive_util::BitWidth(to_type)) {
438         return BitCast(operand_value,
439                        llvm_ir::PrimitiveTypeToIrType(to_type, module_));
440       }
441       return InvalidArgument(
442           "bitcast conversion from primitive type %s to %s with unequal "
443           "bit-widths (%u versus %u) ",
444           PrimitiveType_Name(from_type), PrimitiveType_Name(to_type),
445           primitive_util::BitWidth(from_type),
446           primitive_util::BitWidth(to_type));
447     }
448     case HloOpcode::kExp:
449       return EmitExp(op->shape().element_type(), operand_value, "");
450     case HloOpcode::kExpm1:
451       return EmitExpm1(op->shape().element_type(), operand_value);
452     case HloOpcode::kLog:
453       return EmitLog(op->shape().element_type(), operand_value);
454     case HloOpcode::kLog1p:
455       return EmitLog1p(op->shape().element_type(), operand_value);
456     case HloOpcode::kCos:
457       return EmitCos(op->shape().element_type(), operand_value);
458     case HloOpcode::kSin:
459       return EmitSin(op->shape().element_type(), operand_value);
460     case HloOpcode::kTanh:
461       return EmitTanh(op->shape().element_type(), operand_value);
462     case HloOpcode::kSqrt:
463       return EmitSqrt(op->shape().element_type(), operand_value);
464     case HloOpcode::kRsqrt:
465       return EmitRsqrt(op->shape().element_type(), operand_value);
466     case HloOpcode::kCbrt:
467       return EmitCbrt(op->shape().element_type(), operand_value);
468     case HloOpcode::kFloor:
469       return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
470                                           {operand_value},
471                                           {operand_value->getType()}, b_);
472     case HloOpcode::kCeil:
473       return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ceil,
474                                           {operand_value},
475                                           {operand_value->getType()}, b_);
476     case HloOpcode::kAbs:
477       return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
478                                           {operand_value},
479                                           {operand_value->getType()}, b_);
480     case HloOpcode::kRoundNearestAfz:
481       return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round,
482                                           {operand_value},
483                                           {operand_value->getType()}, b_);
484     case HloOpcode::kSign: {
485       auto type = operand_value->getType();
486       auto zero = llvm::ConstantFP::get(type, 0.0);
487       auto ne0_i1 = FCmpONE(operand_value, zero);
488       auto ne0_float = UIToFP(ne0_i1, type);
489       llvm::Value* result = llvm_ir::EmitCallToIntrinsic(
490           llvm::Intrinsic::copysign, {ne0_float, operand_value},
491           {operand_value->getType()}, b_);
492       auto is_nan = FCmpUNO(operand_value, operand_value);
493       result = Select(is_nan, operand_value, result);
494       return result;
495     }
496     case HloOpcode::kIsFinite: {
497       // abs(x) o!= inf, this works because the comparison returns false if
498       // either operand is NaN.
499       auto type = operand_value->getType();
500       auto abs_value = llvm_ir::EmitCallToIntrinsic(
501           llvm::Intrinsic::fabs, {operand_value}, {type}, b_);
502       auto infinity = llvm::ConstantFP::getInfinity(type);
503       auto not_infinite = FCmpONE(abs_value, infinity);
504       return b_->CreateZExt(not_infinite,
505                             llvm_ir::PrimitiveTypeToIrType(PRED, module_));
506     }
507     case HloOpcode::kNegate:
508       return FNeg(operand_value);
509     case HloOpcode::kReal:
510       return operand_value;
511     case HloOpcode::kImag:
512       return llvm::ConstantFP::get(operand_value->getType(), 0.0);
513     default:
514       return Unimplemented("unary floating-point op '%s'",
515                            HloOpcodeString(op->opcode()));
516   }
517 }
518 
EmitComplexUnaryOp(const HloInstruction * op,llvm::Value * operand_value)519 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
520     const HloInstruction* op, llvm::Value* operand_value) {
521   PrimitiveType input_type = op->operand(0)->shape().element_type();
522   PrimitiveType component_type =
523       primitive_util::IsComplexType(input_type)
524           ? primitive_util::ComplexComponentType(input_type)
525           : input_type;
526   switch (op->opcode()) {
527     case HloOpcode::kLog: {
528       // log(a+bi) = log(abs(a+bi)) + i*atan2(b,a)
529       auto a = EmitExtractReal(operand_value);
530       auto b = EmitExtractImag(operand_value);
531       TF_ASSIGN_OR_RETURN(llvm::Value * angle,
532                           EmitAtan2(component_type, b, a, ""));
533       TF_ASSIGN_OR_RETURN(llvm::Value * abs,
534                           EmitComplexAbs(component_type, operand_value));
535       TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs));
536       return EmitComposeComplex(op, log_abs, angle);
537     }
538     case HloOpcode::kLog1p: {
539       // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
540       auto a = EmitExtractReal(operand_value);
541       auto b = EmitExtractImag(operand_value);
542       llvm::Type* llvm_ty = a->getType();
543       auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
544       auto a_plus_one = FAdd(a, one);
545       auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b));
546       TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
547       TF_ASSIGN_OR_RETURN(auto angle,
548                           EmitAtan2(component_type, b, a_plus_one, ""));
549       auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
550       return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
551     }
552     case HloOpcode::kConvert: {
553       PrimitiveType from_type = op->operand(0)->shape().element_type();
554       TF_RET_CHECK(primitive_util::IsComplexType(from_type));
555       PrimitiveType to_type = op->shape().element_type();
556       TF_RET_CHECK(primitive_util::IsComplexType(to_type));
557       if (from_type == to_type) {
558         return operand_value;
559       }
560       PrimitiveType to_component_type =
561           primitive_util::ComplexComponentType(to_type);
562       auto to_ir_component_type =
563           llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
564       return EmitComposeComplex(
565           op, FPCast(EmitExtractReal(operand_value), to_ir_component_type),
566           FPCast(EmitExtractImag(operand_value), to_ir_component_type));
567     }
568     case HloOpcode::kExp: {
569       // e^(a+bi) = e^a*(cos(b)+sin(b)i)
570       TF_ASSIGN_OR_RETURN(
571           auto exp_a,
572           EmitExp(component_type, EmitExtractReal(operand_value), ""));
573       TF_ASSIGN_OR_RETURN(
574           auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
575       TF_ASSIGN_OR_RETURN(
576           auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
577       return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b));
578     }
579     case HloOpcode::kExpm1: {
580       // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
581       TF_ASSIGN_OR_RETURN(
582           auto exp_a,
583           EmitExp(component_type, EmitExtractReal(operand_value), ""));
584       TF_ASSIGN_OR_RETURN(
585           auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
586       TF_ASSIGN_OR_RETURN(
587           auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
588       auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
589       auto real_result = FSub(FMul(exp_a, cos_b), one);
590       auto imag_result = FMul(exp_a, sin_b);
591       return EmitComposeComplex(op, real_result, imag_result);
592     }
593     case HloOpcode::kCos: {
594       // cos(z) = .5(e^(iz) + e^(-iz))
595       // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
596       // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
597       // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
598       // cos(-x) = cos(x) and sin(-x) = -sin(x), so
599       // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
600       //           = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
601       auto a = EmitExtractReal(operand_value);
602       auto b = EmitExtractImag(operand_value);
603       auto type = a->getType();
604       TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
605       auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
606       auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
607       TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
608       TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
609       return EmitComposeComplex(op,
610                                 FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)),
611                                 FMul(sin_a, FSub(half_exp_neg_b, half_exp_b)));
612     }
613     case HloOpcode::kSin: {
614       // sin(z) = .5i(e^(-iz) - e^(iz))
615       // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
616       //           = .5i(e^(b-ai) - e^(-b+ai))
617       // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
618       // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
619       //           = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
620       // cos(-x) = cos(x) and sin(-x) = -sin(x), so
621       //           = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
622       //           = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b)
623       auto a = EmitExtractReal(operand_value);
624       auto b = EmitExtractImag(operand_value);
625       auto type = a->getType();
626       TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
627       auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
628       auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
629       TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
630       TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
631       return EmitComposeComplex(op,
632                                 FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)),
633                                 FMul(cos_a, FSub(half_exp_b, half_exp_neg_b)));
634     }
635     case HloOpcode::kTanh: {
636       /*
637       tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
638       e^(a+bi) = e^a*(cos(b)+sin(b)i)
639       so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
640               (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
641       cos(b)=cos(-b), sin(-b)=-sin(b)
642       so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
643               (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
644              =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
645               (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
646              =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
647               (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
648       This is a complex division, so we can multiply by denom_conj/denom_conj
649              =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
650               (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
651               ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
652              =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
653                i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
654               ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
655              =(e^(2a)-e^(-2a) +
               i*[cos(b)sin(b)(e^(2a)+2+e^(-2a)) -
                  cos(b)sin(b)(e^(2a)-2+e^(-2a))]) /
              (cos(b)^2*(e^(2a)+2+e^(-2a)) + sin(b)^2*(e^(2a)-2+e^(-2a)))
658              =(e^(2a)-e^(-2a) +
659                i*cos(b)sin(b)*[e^(2a)+2+e^(-2a)-e^(2a)+2-e^(-2a)]) /
660                ([cos(b)^2 + sin(b)^2][e^(2a)+e^(-2a)])+2*[cos(b)^2 - sin(b)^2])
661              =(e^(2a)-e^(-2a) + i*cos(b)sin(b)*4) /
662               (e^(2a)+e^(-2a)+2*[cos(b)^2 - sin(b)^2])
663              =(e^(2a)-e^(-2a) + i*[sin(2b)/2]*4) /
664               (e^(2a)+e^(-2a)+2*[cos(2b)])
665              =(e^(2a)-e^(-2a) + i*2*sin(2b)) / (e^(2a) + e^(-2a) + 2*cos(2b))
666       */
667       llvm::Value* a = EmitExtractReal(operand_value);
668       llvm::Value* b = EmitExtractImag(operand_value);
669 
670       llvm::Type* type = a->getType();
671 
672       llvm::Value* neg_one = llvm::ConstantFP::get(type, -1.F);
673       llvm::Value* two_a = FAdd(a, a);
674       llvm::Value* neg_2a = FMul(neg_one, two_a);
675 
676       // When we are calculating the real numerator, e^(2a)-e^(-2a), for small
677       // values of `a`, we will get a ULP of 2^-23 using the exp function. Using
678       // expm1 to calculate e^(2a)-e^(-2a) = [e^(2a)-1] - [e^(-2a)-1] allows our
679       // ULP to be arbitrarily small. For larger values of `a`, calculating the
680       // numerator as Exp(2a)-Exp(-2a) vs Expm1(2a)-Expm1(-2a) return virtually
681       // identical results.
682       TF_ASSIGN_OR_RETURN(llvm::Value * exp_2a_m1,
683                           EmitExpm1(component_type, two_a));
684       TF_ASSIGN_OR_RETURN(llvm::Value * exp_neg_2a_m1,
685                           EmitExpm1(component_type, neg_2a));
686       llvm::Value* real_numerator = FSub(exp_2a_m1, exp_neg_2a_m1);
687 
688       // We can use the identity cos(2b)+1 = cos(b)^2-sin(b)^2+cos(b)^2+sin(b)^2
689       // = 2cos(b)^2. This gives us the ability to be more precise when the
690       // denominator is close to zero.
691       TF_ASSIGN_OR_RETURN(llvm::Value * cos_b, EmitCos(component_type, b));
692       llvm::Value* four = llvm::ConstantFP::get(type, 4.F);
693       llvm::Value* cos_b_sq = FMul(cos_b, cos_b);
694       llvm::Value* two_cos_2b_p2 = FMul(cos_b_sq, four);
695 
696       // Similarly we can compute sin(2b) with the formula sin(2b) =
697       // 2*sin(b)*cos(b).
698       TF_ASSIGN_OR_RETURN(llvm::Value * sin_b, EmitSin(component_type, b));
699       llvm::Value* imag_numerator = FMul(four, FMul(cos_b, sin_b));
700 
701       // Expm1(x) is about x for small values of x, but exp_sum_m2 is about x^2
702       // for small value of x. As a result, due to floating point precision
703       // issues, x^2 is a better approximation than Expm1(x) + Expm1(x) for
704       // small values of x.
705       llvm::Value* a_sqr = FMul(a, a);
706       llvm::Value* use_approx_cutoff = llvm::ConstantFP::get(type, 1e-8);
707       llvm::Value* use_approx = FCmpOLT(a_sqr, use_approx_cutoff);
708 
709       llvm::Value* exp_sum_m2 =
710           Select(use_approx, a_sqr, FAdd(exp_2a_m1, exp_neg_2a_m1));
711       llvm::Value* denom = FAdd(exp_sum_m2, two_cos_2b_p2);
712 
713       // As `a` grows toward +inf and -inf, the real numerator will grow towards
714       // +inf and -inf respectively, while the denominator will always grow
715       // towards +inf. The result is real_numerator/denom = NaN, when it should
716       // equal +1 and -1 respectively. Therefore, if our denominator is +inf,
717       // we just hardcode the limits for the real numbers.
718       llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
719       llvm::Value* is_inf = FCmpOEQ(exp_sum_m2, inf);
720       llvm::Value* real_limit = llvm_ir::EmitCallToIntrinsic(
721           llvm::Intrinsic::copysign, {neg_one, a}, {type}, b_);
722 
723       llvm::Value* real =
724           Select(is_inf, real_limit, FDiv(real_numerator, denom));
725       llvm::Value* imag = FDiv(imag_numerator, denom);
726 
727       // The complex tanh functions have a few corner cases:
728       // 1. (+0, +0) => (+0, +0)        - Handled normally
729       // 2. (x, +Inf) => (NaN, NaN)     - See below
730       // 3. (x, NaN) => (NaN, NaN)      - See below
731       // 4. (+inf, y) => (1, +0)        - Handled normally
732       // 5. (+Inf, +Inf) => (1, +/-0)   - See below
733       // 6. (+Inf, NaN) => (1, +/-0)    - See below
734       // 7. (NaN, +0) => (NaN, +0)      - See below
735       // 8. (NaN, y) => (NaN, NaN)      - Handled normally
736       // 9. (NaN, NaN) => (NaN, NaN)    - Handled normally
737       //
738       // For the cases that aren't handled normally:
739       // 2/3) Part of the calculation we do is that if exp(a) + exp(-a) = +inf,
740       //      then we return (+/-1, +/-0). However, this is only true if we
741       //      assume that a is infinity or b is finite. In the event that both a
742       //      is finite and b is either +/-Inf or NaN, then our normal
      //      calculation would end up returning (+/-1, NaN), as opposed to
      //      (NaN, NaN).
745       // 5/6) We always calculate the imaginary value as sin(2b)/denominator.
746       //      When the denominator is infinity, this assures us that the zero is
747       //      the correct sign. However if our imaginary input results in
748       //      sin(2b) = NaN, we calculate our imaginary result as NaN.
749       // 7)   In the event that a is NaN, the denominator will be NaN.
750       //      Therefore, the normal calculation gives (NaN, NaN) while we need
751       //      (NaN, +0).
752       if (!(b_->getFastMathFlags().noNaNs() &&
753             b_->getFastMathFlags().noInfs())) {
754         llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
755                                                           {a}, {type}, b_);
756         llvm::Value* zero = llvm::ConstantFP::get(type, 0.F);
757         llvm::Value* nan = llvm::ConstantFP::getNaN(type);
758 
759         llvm::Value* a_is_inf = FCmpOEQ(abs_a, inf);
760         llvm::Value* b_is_zero = FCmpOEQ(b, zero);
761 
762         // imag_numerator = 2sin(2b), so sin(2b) is NaN if and only if
763         // imag_numerator is NaN.
764         llvm::Value* sin_2b_is_nan =
765             b_->CreateFCmpUNO(imag_numerator, imag_numerator);
766 
767         llvm::Value* real_is_nan =
768             b_->CreateAnd(sin_2b_is_nan, b_->CreateNot(a_is_inf));
769         llvm::Value* imag_is_zero =
770             b_->CreateOr(b_is_zero, b_->CreateAnd(a_is_inf, sin_2b_is_nan));
771 
772         real = Select(real_is_nan, nan, real);
773         imag = Select(imag_is_zero, zero, imag);
774       }
775 
776       return EmitComposeComplex(op, real, imag);
777     }
778     case HloOpcode::kAbs: {
779       return EmitComplexAbs(component_type, operand_value);
780     }
781     case HloOpcode::kSign: {  // Sign(c) = c / |c|
782       TF_ASSIGN_OR_RETURN(auto cplx_abs,
783                           EmitComplexAbs(component_type, operand_value));
784       auto type = cplx_abs->getType();
785       auto zero = llvm::ConstantFP::get(type, 0.0);
786       auto oeq = FCmpOEQ(cplx_abs, zero);
787       return Select(
788           oeq, EmitComposeComplex(op, zero, zero),
789           EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs),
790                              FDiv(EmitExtractImag(operand_value), cplx_abs)));
791     }
792     case HloOpcode::kSqrt: {
793       return EmitComplexSqrt(op, component_type, operand_value);
794     }
795     case HloOpcode::kRsqrt: {
796       return EmitComplexRsqrt(op, component_type, operand_value);
797     }
798     case HloOpcode::kCbrt: {
799       return EmitComplexCbrt(op, component_type, operand_value);
800     }
801     case HloOpcode::kNegate:
802       return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)),
803                                 FNeg(EmitExtractImag(operand_value)));
804     case HloOpcode::kReal:
805       return EmitExtractReal(operand_value);
806     case HloOpcode::kImag:
807       return EmitExtractImag(operand_value);
808     default:
809       return Unimplemented("unary complex op '%s'",
810                            HloOpcodeString(op->opcode()));
811   }
812 }
813 
EmitBinaryOp(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value)814 StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
815     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
816   PrimitiveType operand_type = op->operand(0)->shape().element_type();
817   if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
818       operand_type == PRED) {
819     return EmitIntegerBinaryOp(
820         op, lhs_value, rhs_value,
821         primitive_util::IsSignedIntegralType(operand_type));
822   } else if (primitive_util::IsComplexType(operand_type)) {
823     return EmitComplexBinaryOp(op, lhs_value, rhs_value);
824   } else {
825     return EmitFloatBinaryOp(op, lhs_value, rhs_value);
826   }
827 }
828 
EmitFloatBinaryOp(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value)829 StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
830     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
831   switch (op->opcode()) {
832     case HloOpcode::kComplex:
833       return EmitComposeComplex(op, lhs_value, rhs_value);
834     case HloOpcode::kAdd:
835       return FAdd(lhs_value, rhs_value, op->name());
836     case HloOpcode::kSubtract:
837       return FSub(lhs_value, rhs_value, op->name());
838     case HloOpcode::kMultiply:
839       return FMul(lhs_value, rhs_value, op->name());
840     case HloOpcode::kDivide:
841       return FDiv(lhs_value, rhs_value, op->name());
842     case HloOpcode::kRemainder:
843       return FRem(lhs_value, rhs_value, op->name());
844     // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
845     // comparisons always return false when one of the operands is NaN, whereas
846     // unordered comparisons return true.
847     //
848     // We use ordered comparisons for everything except kNe, where we use an
849     // unordered comparison.  This makes x != y equivalent to !(x == y), and
850     // matches C++'s semantics.
851     case HloOpcode::kCompare: {
852       switch (op->comparison_direction()) {
853         case ComparisonDirection::kEq:
854           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
855                                          rhs_value, b_, op->name());
856         case ComparisonDirection::kNe:
857           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
858                                          rhs_value, b_, op->name());
859         case ComparisonDirection::kLt:
860           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
861                                          rhs_value, b_, op->name());
862         case ComparisonDirection::kGt:
863           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
864                                          rhs_value, b_, op->name());
865         case ComparisonDirection::kLe:
866           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
867                                          rhs_value, b_, op->name());
868         case ComparisonDirection::kGe:
869           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
870                                          rhs_value, b_, op->name());
871       }
872     }
873     case HloOpcode::kMaximum:
874       return EmitFloatMax(lhs_value, rhs_value, op->name());
875     case HloOpcode::kMinimum:
876       return EmitFloatMin(lhs_value, rhs_value, op->name());
877     case HloOpcode::kPower:
878       return EmitPow(op->shape().element_type(), lhs_value, rhs_value,
879                      op->name());
880     case HloOpcode::kAtan2:
881       return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value,
882                        op->name());
883     default:
884       return Unimplemented("binary floating point op '%s'",
885                            HloOpcodeString(op->opcode()));
886   }
887 }
888 
889 // Using sqrt(a^2 + b^2) can cause overflow errors. Therefore we can use
890 // sqrt(a^2 + b^2) = sqrt(a^2 * (1 + b^2/a^2))
891 //                 = |a| * sqrt(1 + (b/a)^2)
892 // With the assumption that |a| >= |b|.
893 //
894 // This method returns the min, max, and sqrt term for this calculation. This is
895 // done to prevent potential overflow errors that can occur from multiplying the
896 // max with the sqrt term. (i.e. when calculating the sqrt of the absolute
897 // value, we can take the sqrt of the max and the sqrt term before multiplying
898 // them together.) If return_sqrt is false, it returns 1 + (b/a)^2 instead of
899 // sqrt(1 + (b/a)^2).
900 StatusOr<std::tuple<llvm::Value*, llvm::Value*, llvm::Value*>>
EmitComplexAbsHelper(PrimitiveType prim_type,llvm::Value * operand_value,bool return_sqrt)901 ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type,
902                                          llvm::Value* operand_value,
903                                          bool return_sqrt) {
904   llvm::Value* real = EmitExtractReal(operand_value);
905   llvm::Value* imag = EmitExtractImag(operand_value);
906   llvm::Value* abs_real = llvm_ir::EmitCallToIntrinsic(
907       llvm::Intrinsic::fabs, {real}, {real->getType()}, b_);
908   llvm::Value* abs_imag = llvm_ir::EmitCallToIntrinsic(
909       llvm::Intrinsic::fabs, {imag}, {imag->getType()}, b_);
910   llvm::Value* max = EmitFloatMax(abs_real, abs_imag, "");
911   llvm::Value* min = EmitFloatMin(abs_real, abs_imag, "");
912 
913   llvm::Value* div = FDiv(min, max);
914   llvm::Value* div_sq = FMul(div, div);
915   llvm::Value* one = llvm::ConstantFP::get(max->getType(), 1);
916   llvm::Value* one_p_div_sq = FAdd(one, div_sq);
917   TF_ASSIGN_OR_RETURN(llvm::Value * sqrt, EmitSqrt(prim_type, one_p_div_sq));
918   return std::make_tuple(min, max, return_sqrt ? sqrt : one_p_div_sq);
919 }
920 
EmitComplexAbs(PrimitiveType prim_type,llvm::Value * operand_value)921 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexAbs(
922     PrimitiveType prim_type, llvm::Value* operand_value) {
923   llvm::Value* min;
924   llvm::Value* max;
925   llvm::Value* sqrt;
926   TF_ASSIGN_OR_RETURN(
927       std::tie(min, max, sqrt),
928       EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
929   llvm::Value* result = FMul(max, sqrt);
930   // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
931   // In such cases, we return `min` instead of `result`.
932   return Select(FCmpUNO(result, result), min, result);
933 }
934 
935 // Calculates ComplexAbs in the same way, except using:
936 // sqrt(|a| * sqrt(1 + (b/a)^2)) = sqrt(|a|) * pow(1 + (b/a)^2, .25)
EmitSqrtComplexAbs(PrimitiveType prim_type,llvm::Value * operand_value)937 StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrtComplexAbs(
938     PrimitiveType prim_type, llvm::Value* operand_value) {
939   llvm::Value* min;
940   llvm::Value* max;
941   llvm::Value* one_p_div_sq;
942   TF_ASSIGN_OR_RETURN(
943       std::tie(min, max, one_p_div_sq),
944       EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/false));
945   TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max));
946   TF_ASSIGN_OR_RETURN(llvm::Value * pow,
947                       EmitPow(prim_type, one_p_div_sq,
948                               llvm::ConstantFP::get(max->getType(), .25), ""));
949   llvm::Value* result = FMul(sqrt_max, pow);
950   // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
951   // In such cases, we return `min` instead of `result`.
952   return Select(FCmpUNO(result, result), min, result);
953 }
954 
955 // Calculates ComplexAbs in the same way, except using:
956 // rsqrt(|a| * sqrt(1 + (b/a)^2)) = rsqrt(|a|) * rsqrt(sqrt(1 + (b/a)^2))
EmitRsqrtComplexAbs(PrimitiveType prim_type,llvm::Value * operand_value)957 StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrtComplexAbs(
958     PrimitiveType prim_type, llvm::Value* operand_value) {
959   llvm::Value* min;
960   llvm::Value* max;
961   llvm::Value* sqrt;
962   TF_ASSIGN_OR_RETURN(
963       std::tie(min, max, sqrt),
964       EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true));
965   TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_max, EmitRsqrt(prim_type, max));
966   TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_sqrt, EmitRsqrt(prim_type, sqrt));
967   llvm::Value* result = FMul(rsqrt_max, rsqrt_sqrt);
968   TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_min, EmitRsqrt(prim_type, min));
969   // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
970   // In such cases, we return rsqrt(min) instead of `result`.
971   return Select(FCmpUNO(result, result), rsqrt_min, result);
972 }
973 
974 // Using our EmitComplexPower formula, but setting c=0.5 and d=0, we get:
975 //   e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
976 // = e^[ln(r)*0.5] * [cos(t*0.5) + i*sin(t*0.5)]
977 // = r^0.5 * [cos(t/2) + i*sin(t/2)]
978 // = sqrt(r) * [cos(t/2) + i*sin(t/2)]
979 // where r = |a+bi| and t = atan2(b,a)
980 // TODO(bixia): See doc for implementation without atan2.
EmitComplexSqrt(const HloInstruction * op,PrimitiveType prim_type,llvm::Value * operand_value)981 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSqrt(
982     const HloInstruction* op, PrimitiveType prim_type,
983     llvm::Value* operand_value) {
984   llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
985                          ->getElementType(0);
986 
987   TF_ASSIGN_OR_RETURN(llvm::Value * r,
988                       EmitSqrtComplexAbs(prim_type, operand_value));
989 
990   llvm::Value* a = EmitExtractReal(operand_value);
991   llvm::Value* b = EmitExtractImag(operand_value);
992   TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));
993 
994   llvm::Value* c = llvm::ConstantFP::get(type, 0.5);
995   llvm::Value* angle = FMul(t, c);
996   TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
997   TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));
998 
999   llvm::Value* real_part;
1000   llvm::Value* imag_part;
1001 
1002   llvm::Value* zero = llvm::ConstantFP::get(type, 0);
1003 
1004   if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
1005     llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
1006     llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
1007     llvm::Value* nan = llvm::ConstantFP::getNaN(type);
1008     llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
1009                                                       {b}, {b->getType()}, b_);
1010 
1011     real_part = Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, inf)), inf,
1012                        Select(And(FCmpOEQ(a, neg_inf), FCmpONE(abs_b, inf)),
1013                               zero, FMul(r, cos)));
1014 
1015     llvm::Value* b_signed_inf = llvm_ir::EmitCallToIntrinsic(
1016         llvm::Intrinsic::copysign, {inf, b}, {b->getType()}, b_);
1017     imag_part =
1018         Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, neg_inf)), b_signed_inf,
1019                Select(FCmpUNO(r, r), nan,
1020                       Select(FCmpOEQ(sin, zero), sin, FMul(r, sin))));
1021   } else {
1022     real_part = FMul(r, cos);
1023     imag_part = Select(FCmpOEQ(sin, zero), sin, FMul(r, sin));
1024   }
1025 
1026   return Select(FCmpOEQ(r, zero), EmitComposeComplex(op, zero, zero),
1027                 EmitComposeComplex(op, real_part, imag_part));
1028 }
1029 
1030 // Similar to Sqrt, we can use our EmitComplexPower formula, but set
1031 // c=-0.5 and d=0. We get:
1032 //   e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)]
1033 // = e^[ln(r)*-0.5] * [cos(t*-0.5) + i*sin(t*-0.5)]
1034 // = r^(-0.5) * [cos(-t/2) + i*sin(-t/2)]
1035 // = rsqrt(r) * [cos(-t/2) + i*sin(-t/2)]
1036 // where r = |a+bi| and t = atan2(b,a).
EmitComplexRsqrt(const HloInstruction * op,PrimitiveType prim_type,llvm::Value * operand_value)1037 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexRsqrt(
1038     const HloInstruction* op, PrimitiveType prim_type,
1039     llvm::Value* operand_value) {
1040   llvm::Type* type = static_cast<llvm::StructType*>(operand_value->getType())
1041                          ->getElementType(0);
1042 
1043   TF_ASSIGN_OR_RETURN(llvm::Value * r,
1044                       EmitRsqrtComplexAbs(prim_type, operand_value));
1045 
1046   llvm::Value* a = EmitExtractReal(operand_value);
1047   llvm::Value* b = EmitExtractImag(operand_value);
1048   TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));
1049 
1050   llvm::Value* c = llvm::ConstantFP::get(type, -0.5);
1051   llvm::Value* angle = FMul(t, c);
1052   TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle));
1053   TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle));
1054 
1055   llvm::Value* real_part = FMul(r, cos);
1056   llvm::Value* imag_part = FMul(r, sin);
1057 
1058   if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) {
1059     llvm::Value* zero = llvm::ConstantFP::get(type, 0);
1060     llvm::Value* neg_one = llvm::ConstantFP::get(type, -1);
1061     llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
1062     llvm::Value* nan = llvm::ConstantFP::getNaN(type);
1063     // llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true);
1064     llvm::Value* a_signed_zero = llvm_ir::EmitCallToIntrinsic(
1065         llvm::Intrinsic::copysign, {zero, a}, {a->getType()}, b_);
1066     llvm::Value* b_signed_zero = llvm_ir::EmitCallToIntrinsic(
1067         llvm::Intrinsic::copysign, {zero, b}, {b->getType()}, b_);
1068     llvm::Value* neg_b_signed_zero = FMul(b_signed_zero, neg_one);
1069 
1070     llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
1071                                                       {a}, {a->getType()}, b_);
1072     llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
1073                                                       {b}, {b->getType()}, b_);
1074 
1075     llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
1076     real_part = Select(
1077         is_zero_zero, inf,
1078         Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
1079                a_signed_zero, FMul(r, cos)));
1080     imag_part = Select(
1081         is_zero_zero, nan,
1082         Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)),
1083                neg_b_signed_zero, FMul(r, sin)));
1084   } else {
1085     llvm::Value* zero = llvm::ConstantFP::get(type, 0);
1086     llvm::Value* inf = llvm::ConstantFP::getInfinity(type);
1087     llvm::Value* nan = llvm::ConstantFP::getNaN(type);
1088 
1089     llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero));
1090     real_part = Select(is_zero_zero, inf, FMul(r, cos));
1091     imag_part = Select(is_zero_zero, nan, FMul(r, sin));
1092   }
1093 
1094   return EmitComposeComplex(op, real_part, imag_part);
1095 }
1096 
1097 //
1098 // Using EmitComplexPower with c=1.0/3.0 and d=0
EmitComplexCbrt(const HloInstruction * op,PrimitiveType prim_type,llvm::Value * operand_value)1099 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexCbrt(
1100     const HloInstruction* op, PrimitiveType prim_type,
1101     llvm::Value* operand_value) {
1102   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1103   auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
1104   auto zero = llvm::ConstantFP::get(type, 0);
1105   llvm::Value* a = EmitExtractReal(operand_value);
1106   llvm::Value* b = EmitExtractImag(operand_value);
1107   return EmitComplexPower(op, a, b, third, zero);
1108 }
1109 
1110 // (a+bi)^(c+di) =
1111 //    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
1112 //    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
EmitComplexPower(const HloInstruction * op,llvm::Value * a,llvm::Value * b,llvm::Value * c,llvm::Value * d)1113 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
1114     const HloInstruction* op, llvm::Value* a, llvm::Value* b, llvm::Value* c,
1115     llvm::Value* d) {
1116   PrimitiveType component_type =
1117       primitive_util::ComplexComponentType(op->shape().element_type());
1118   auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b));
1119   auto zero = llvm::ConstantFP::get(a->getType(), 0);
1120   auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
1121   auto one = llvm::ConstantFP::get(a->getType(), 1);
1122   auto half_c = FMul(one_half, c);
1123 
1124   TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
1125                       EmitPow(component_type, aa_p_bb, half_c, ""));
1126 
1127   auto neg_d = FNeg(d);
1128   TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a, ""));
1129   auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
1130   TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
1131                       EmitExp(component_type, neg_d_arg_lhs, ""));
1132   auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
1133   TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
1134   auto half_d = FMul(one_half, d);
1135   auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb));
1136   TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
1137   TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
1138   // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
1139   // Branch Cuts for Complex Elementary Functions or Much Ado About
1140   // Nothing's Sign Bit, W. Kahan, Section 10.
1141   return Select(
1142       And(And(FCmpOEQ(aa_p_bb, zero), FCmpOEQ(d, zero)), FCmpOLE(zero, c)),
1143       EmitComposeComplex(op, Select(FCmpOEQ(zero, c), one, zero), zero),
1144       EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)));
1145 }
1146 
EmitComplexBinaryOp(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value)1147 StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
1148     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) {
1149   switch (op->opcode()) {
1150     case HloOpcode::kAdd:
1151       return EmitComposeComplex(
1152           op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
1153           FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
1154     case HloOpcode::kSubtract:
1155       return EmitComposeComplex(
1156           op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
1157           FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
1158     case HloOpcode::kMultiply:
1159       return EmitComposeComplex(
1160           op,
1161           FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
1162                FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))),
1163           FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
1164                FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))));
1165     case HloOpcode::kDivide: {
1166       // Division of complex numbers is implemented here, taking into account
1167       // over/underflow, NaN and Inf values.
1168       auto a_r = EmitExtractReal(lhs_value);
1169       auto a_i = EmitExtractImag(lhs_value);
1170       auto b_r = EmitExtractReal(rhs_value);
1171       auto b_i = EmitExtractImag(rhs_value);
1172       auto type = a_r->getType();
1173 
1174       // Smith's algorithm to divide complex numbers. It is just a bit smarter
1175       // way to compute the following formula:
1176       //  (a_r + a_i * i) / (b_r + b_i * i)
1177       //    = (a_r + a_i * i) (b_r - b_i * i) / ((b_r + b_i * i)(b_r - b_i * i))
1178       //    = ((a_r * b_r + a_i * b_i) + (a_i * b_r - a_r * b_i) * i) / ||b||^2
1179       //
1180       // Depending on whether |b_r| < |b_i| we compute either
1181       //   b_r_b_i_ratio = b_r / b_i
1182       //   b_r_b_i_denom = b_i + b_r * b_r_b_i_ratio
1183       //   c_r = (a_r * b_r_b_i_ratio + a_i ) / b_r_b_i_denom
1184       //   c_i = (a_i * b_r_b_i_ratio - a_r ) / b_r_b_i_denom
1185       //
1186       // or
1187       //
1188       //   b_i_b_r_ratio = b_i / b_r
1189       //   b_i_b_r_denom = b_r + b_i * b_i_b_r_denom
1190       //   c_r = (a_r + a_i * b_i_b_r_ratio ) / b_i_b_r_denom
1191       //   c_i = (a_i - a_r * b_i_b_r_ratio ) / b_i_b_r_denom
1192       //
1193       // See https://dl.acm.org/citation.cfm?id=368661 for more details.
1194       auto b_r_b_i_ratio = FDiv(b_r, b_i);
1195       auto b_r_b_i_denom = FAdd(b_i, FMul(b_r_b_i_ratio, b_r));
1196       auto b_i_b_r_ratio = FDiv(b_i, b_r);
1197       auto b_i_b_r_denom = FAdd(b_r, FMul(b_i_b_r_ratio, b_i));
1198 
1199       auto b_r_lt_b_i = FCmpOLT(llvm_ir::EmitCallToIntrinsic(
1200                                     llvm::Intrinsic::fabs, {b_r}, {type}, b_),
1201                                 llvm_ir::EmitCallToIntrinsic(
1202                                     llvm::Intrinsic::fabs, {b_i}, {type}, b_));
1203       auto c_r = Select(
1204           b_r_lt_b_i, FDiv(FAdd(FMul(b_r_b_i_ratio, a_r), a_i), b_r_b_i_denom),
1205           FDiv(FAdd(FMul(b_i_b_r_ratio, a_i), a_r), b_i_b_r_denom));
1206       auto c_i = Select(
1207           b_r_lt_b_i, FDiv(FSub(FMul(b_r_b_i_ratio, a_i), a_r), b_r_b_i_denom),
1208           FDiv(FSub(a_i, FMul(b_i_b_r_ratio, a_r)), b_i_b_r_denom));
1209       auto result = EmitComposeComplex(op, c_r, c_i);
1210 
1211       // Consider corner cases, if the result is (NaN, NaN).
1212       auto zero = llvm::ConstantFP::get(type, 0.0);
1213       auto one = llvm::ConstantFP::get(type, 1.0);
1214       auto inf = llvm::ConstantFP::getInfinity(type);
1215 
1216       // Case 1. Zero denominator.
1217       auto zero_denominator =
1218           And(And(FCmpOEQ(b_r, zero), FCmpOEQ(b_i, zero)),
1219               Or(Neg(FCmpONE(a_r, zero)), Neg(FCmpONE(a_i, zero))));
1220       auto inf_with_sign_of_c = llvm_ir::EmitCallToIntrinsic(
1221           llvm::Intrinsic::copysign, {inf, a_r}, {type}, b_);
1222       auto zero_denominator_result = EmitComposeComplex(
1223           op, FMul(inf_with_sign_of_c, a_r), FMul(inf_with_sign_of_c, a_i));
1224 
1225       // Case 2. Infinite numerator, finite denominator.
1226       auto b_r_finite = FCmpONE(llvm_ir::EmitCallToIntrinsic(
1227                                     llvm::Intrinsic::fabs, {b_r}, {type}, b_),
1228                                 inf);
1229       auto b_i_finite = FCmpONE(llvm_ir::EmitCallToIntrinsic(
1230                                     llvm::Intrinsic::fabs, {b_i}, {type}, b_),
1231                                 inf);
1232       auto inf_num_finite_denom = And(Or(FCmpOEQ(a_r, inf), FCmpOEQ(a_i, inf)),
1233                                       And(b_r_finite, b_i_finite));
1234 
1235       auto a_r_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
1236           llvm::Intrinsic::copysign,
1237           {Select(FCmpOEQ(a_r, inf), one, zero), a_r}, {type}, b_);
1238       auto a_i_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
1239           llvm::Intrinsic::copysign,
1240           {Select(FCmpOEQ(a_i, inf), one, zero), a_i}, {type}, b_);
1241       auto inf_num_finite_denom_result =
1242           EmitComposeComplex(op,
1243                              FMul(inf, FAdd(FMul(a_r_inf_with_sign, b_r),
1244                                             FMul(a_i_inf_with_sign, b_i))),
1245                              FMul(inf, FSub(FMul(a_i_inf_with_sign, b_r),
1246                                             FMul(a_r_inf_with_sign, b_i))));
1247 
1248       // Case 3. Finite numerator, infinite denominator.
1249       auto a_r_finite = FCmpONE(llvm_ir::EmitCallToIntrinsic(
1250                                     llvm::Intrinsic::fabs, {a_r}, {type}, b_),
1251                                 inf);
1252       auto a_i_finite = FCmpONE(llvm_ir::EmitCallToIntrinsic(
1253                                     llvm::Intrinsic::fabs, {a_i}, {type}, b_),
1254                                 inf);
1255       auto finite_num_inf_denom = And(Or(FCmpOEQ(b_r, inf), FCmpOEQ(b_i, inf)),
1256                                       And(a_r_finite, a_i_finite));
1257 
1258       auto b_r_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
1259           llvm::Intrinsic::copysign,
1260           {Select(FCmpOEQ(b_r, inf), one, zero), b_r}, {type}, b_);
1261       auto b_i_inf_with_sign = llvm_ir::EmitCallToIntrinsic(
1262           llvm::Intrinsic::copysign,
1263           {Select(FCmpOEQ(b_i, inf), one, zero), b_i}, {type}, b_);
1264       auto finite_num_inf_denom_result =
1265           EmitComposeComplex(op,
1266                              FMul(zero, FAdd(FMul(a_r, b_r_inf_with_sign),
1267                                              FMul(a_i, b_i_inf_with_sign))),
1268                              FMul(zero, FSub(FMul(a_i, b_r_inf_with_sign),
1269                                              FMul(a_r, b_i_inf_with_sign))));
1270 
1271       auto c_nan = And(FCmpUNO(c_r, zero), FCmpUNO(c_i, zero));
1272       return Select(
1273           c_nan,
1274           Select(zero_denominator, zero_denominator_result,
1275                  Select(inf_num_finite_denom, inf_num_finite_denom_result,
1276                         Select(finite_num_inf_denom,
1277                                finite_num_inf_denom_result, result))),
1278           result);
1279     }
1280     // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
1281     // comparisons always return false when one of the operands is NaN, whereas
1282     // unordered comparisons return true.
1283     //
1284     // We use ordered comparisons for everything except kNe, where we use an
1285     // unordered comparison.  This makes x != y equivalent to !(x == y), and
1286     // matches C++'s semantics.
1287     case HloOpcode::kCompare: {
1288       switch (op->comparison_direction()) {
1289         case ComparisonDirection::kEq:
1290           return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
1291                                              EmitExtractReal(lhs_value),
1292                                              EmitExtractReal(rhs_value), b_),
1293                      llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
1294                                              EmitExtractImag(lhs_value),
1295                                              EmitExtractImag(rhs_value), b_));
1296         case ComparisonDirection::kNe:
1297           return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
1298                                             EmitExtractReal(lhs_value),
1299                                             EmitExtractReal(rhs_value), b_),
1300                     llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
1301                                             EmitExtractImag(lhs_value),
1302                                             EmitExtractImag(rhs_value), b_));
1303         default:
1304           return Unimplemented(
1305               "complex comparison '%s'",
1306               ComparisonDirectionToString(op->comparison_direction()));
1307       }
1308     }
1309     case HloOpcode::kPower: {
1310       auto a = EmitExtractReal(lhs_value);
1311       auto b = EmitExtractImag(lhs_value);
1312       auto c = EmitExtractReal(rhs_value);
1313       auto d = EmitExtractImag(rhs_value);
1314       return EmitComplexPower(op, a, b, c, d);
1315     }
1316     default:
1317       return Unimplemented("binary complex op '%s'",
1318                            HloOpcodeString(op->opcode()));
1319   }
1320 }
1321 
EmitFloatMax(llvm::Value * lhs_value,llvm::Value * rhs_value,absl::string_view name)1322 llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
1323                                               llvm::Value* rhs_value,
1324                                               absl::string_view name) {
1325   return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max(), name);
1326 }
1327 
EmitFloatMin(llvm::Value * lhs_value,llvm::Value * rhs_value,absl::string_view name)1328 llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
1329                                               llvm::Value* rhs_value,
1330                                               absl::string_view name) {
1331   return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max(), name);
1332 }
1333 
EmitLog(PrimitiveType prim_type,llvm::Value * value)1334 StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
1335                                                    llvm::Value* value) {
1336   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
1337                                       {value->getType()}, b_);
1338 }
1339 
EmitLog1p(PrimitiveType prim_type,llvm::Value * value)1340 StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
1341                                                      llvm::Value* value) {
1342   auto x = value;
1343   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1344   auto one = llvm::ConstantFP::get(type, 1.0);
1345   auto negative_half = llvm::ConstantFP::get(type, -0.5);
1346   // When x is large, the naive evaluation of ln(x + 1) is more
1347   // accurate than the Taylor series.
1348   TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
1349   // When x is small, (defined to be less than sqrt(2) / 2), use a rational
1350   // approximation. The approximation below is based on one from the Cephes
1351   // Mathematical Library.
1352   //
1353   // sqrt(2) - 1.
1354   const auto kAntilogarithmIsSmallThreshold = 0.41421356237309504880;
1355 
1356   static const std::array<double, 7> kDenominatorCoeffs{
1357       1.,
1358       1.5062909083469192043167E1,
1359       8.3047565967967209469434E1,
1360       2.2176239823732856465394E2,
1361       3.0909872225312059774938E2,
1362       2.1642788614495947685003E2,
1363       6.0118660497603843919306E1,
1364   };
1365 
1366   static const std::array<double, 7> kNumeratorCoeffs{
1367       4.5270000862445199635215E-5, 4.9854102823193375972212E-1,
1368       6.5787325942061044846969E0,  2.9911919328553073277375E1,
1369       6.0949667980987787057556E1,  5.7112963590585538103336E1,
1370       2.0039553499201281259648E1,
1371   };
1372 
1373   auto x_squared = FMul(x, x);
1374   TF_ASSIGN_OR_RETURN(auto denominator,
1375                       EvaluatePolynomial(type, x, kDenominatorCoeffs));
1376   TF_ASSIGN_OR_RETURN(auto numerator,
1377                       EvaluatePolynomial(type, x, kNumeratorCoeffs));
1378   auto for_small_x = FDiv(numerator, denominator);
1379   for_small_x = FMul(FMul(x, x_squared), for_small_x);
1380   for_small_x = FAdd(FMul(negative_half, x_squared), for_small_x);
1381   for_small_x = FAdd(x, for_small_x);
1382 
1383   auto abs_x =
1384       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
1385   auto x_is_small = FCmpOLT(
1386       abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
1387   return Select(x_is_small, for_small_x, for_large_x);
1388 }
1389 
EmitSqrt(PrimitiveType,llvm::Value * value)1390 StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrt(PrimitiveType,
1391                                                     llvm::Value* value) {
1392   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value},
1393                                       {value->getType()}, b_);
1394 }
1395 
EmitRsqrt(PrimitiveType prim_type,llvm::Value * value)1396 StatusOr<llvm::Value*> ElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
1397                                                      llvm::Value* value) {
1398   TF_ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, value));
1399   return FDiv(llvm::ConstantFP::get(sqrt->getType(), 1.0), sqrt);
1400 }
1401 
EmitSin(PrimitiveType prim_type,llvm::Value * value)1402 StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
1403                                                    llvm::Value* value) {
1404   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
1405                                       {value->getType()}, b_);
1406 }
1407 
EmitCos(PrimitiveType prim_type,llvm::Value * value)1408 StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
1409                                                    llvm::Value* value) {
1410   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
1411                                       {value->getType()}, b_);
1412 }
1413 
EmitExp(PrimitiveType prim_type,llvm::Value * value,absl::string_view name)1414 StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
1415                                                    llvm::Value* value,
1416                                                    absl::string_view name) {
1417   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
1418                                       {value->getType()}, b_, name);
1419 }
1420 
EmitExpm1(PrimitiveType prim_type,llvm::Value * value)1421 StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
1422                                                      llvm::Value* value) {
1423   auto x = value;
1424   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1425   auto one = llvm::ConstantFP::get(type, 1.0);
1426   auto half = llvm::ConstantFP::get(type, 0.5);
1427   // When the exponent is large, the naive evaluation of e^(x) - 1 is more
1428   // accurate than the Taylor series.
1429   TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value, ""));
1430   auto for_large_x = FSub(exp_x, one);
1431   // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + ….
1432   // We want exp(x)-1 which is x + x^2/2 + x^3/6 + ….
1433   // We use the second degree approximation of exp(x)-1 = x + x^2/2.
1434   auto x_squared = FMul(x, x);
1435   auto x_squared_over_two = FMul(x_squared, half);
1436   auto for_small_x = FAdd(x, x_squared_over_two);
1437   // At this point, the relative errors due to floating point precision loss of
1438   // calculating exp(x) - 1 and the polynomial exp(x)-1 = x + x^2/2 are about
1439   // equal, with a value of approximately 2^-16.
1440   const auto kExponentIsSmallThreshold = 0.009;
1441   auto abs_x =
1442       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
1443   auto x_is_small =
1444       FCmpOLT(abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold));
1445   return Select(x_is_small, for_small_x, for_large_x);
1446 }
1447 
EmitPow(PrimitiveType prim_type,llvm::Value * lhs,llvm::Value * rhs,absl::string_view name)1448 StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
1449                                                    llvm::Value* lhs,
1450                                                    llvm::Value* rhs,
1451                                                    absl::string_view name) {
1452   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
1453                                       {lhs->getType()}, b_, name);
1454 }
1455 
EmitCbrt(PrimitiveType prim_type,llvm::Value * value)1456 StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
1457                                                     llvm::Value* value) {
1458   auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
1459   auto third = llvm::ConstantFP::get(type, 1.0 / 3.0);
1460   auto abs_value =
1461       llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
1462   TF_ASSIGN_OR_RETURN(llvm::Value * abs_res,
1463                       EmitPow(prim_type, abs_value, third, ""));
1464   auto signed_res = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
1465                                                  {abs_res, value}, {type}, b_);
1466   return signed_res;
1467 }
1468 
EmitAtan2(PrimitiveType prim_type,llvm::Value * lhs,llvm::Value *,absl::string_view)1469 StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(
1470     PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* /*rhs*/,
1471     absl::string_view /*name*/) {
1472   return Unimplemented("atan2");
1473 }
1474 
EmitTanh(PrimitiveType prim_type,llvm::Value * value)1475 StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
1476                                                     llvm::Value* value) {
1477   return Unimplemented("tanh");
1478 }
1479 
EmitReducePrecision(const HloInstruction * hlo,llvm::Value * x)1480 StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
1481     const HloInstruction* hlo, llvm::Value* x) {
1482   return EmitReducePrecisionIR(
1483       /*src_ty=*/hlo->operand(0)->shape().element_type(), x,
1484       /*dest_exponent_bits=*/hlo->exponent_bits(),
1485       /*dest_mantissa_bits=*/hlo->mantissa_bits(), b_);
1486 }
1487 
SaturateShiftIfNecessary(llvm::IRBuilder<> * b,llvm::Value * lhs,llvm::Value * rhs,llvm::Value * shift_result,bool saturate_to_sign_bit)1488 static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
1489                                              llvm::Value* lhs, llvm::Value* rhs,
1490                                              llvm::Value* shift_result,
1491                                              bool saturate_to_sign_bit) {
1492   llvm::IntegerType* integer_type =
1493       llvm::cast<llvm::IntegerType>(lhs->getType());
1494   unsigned integer_bitsize = integer_type->getBitWidth();
1495   llvm::ConstantInt* integer_bitsize_constant =
1496       llvm::ConstantInt::get(integer_type, integer_bitsize);
1497   llvm::ConstantInt* zero = llvm::ConstantInt::get(integer_type, 0);
1498   llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1);
1499   llvm::Value* saturated_value;
1500   if (saturate_to_sign_bit) {
1501     saturated_value =
1502         b->CreateSelect(b->CreateICmpSLT(lhs, zero), minus_one, zero);
1503   } else {
1504     saturated_value = zero;
1505   }
1506   llvm::Value* shift_amt_in_range =
1507       b->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk");
1508   return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value);
1509 }
1510 
GetOne(llvm::Type * type)1511 llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) {
1512   return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1);
1513 }
1514 
GetZero(llvm::Type * type)1515 llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) {
1516   return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0);
1517 }
1518 
GetIntSMin(llvm::Type * type)1519 llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) {
1520   auto* integer_type = llvm::cast<llvm::IntegerType>(type);
1521   return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue(
1522                                                   integer_type->getBitWidth()));
1523 }
1524 
GetMinusOne(llvm::Type * type)1525 llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) {
1526   auto* integer_type = llvm::cast<llvm::IntegerType>(type);
1527   return llvm::ConstantInt::get(
1528       integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth()));
1529 }
1530 
IsZero(llvm::Value * v)1531 llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) {
1532   return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0));
1533 }
1534 
IsIntMinDivisionOverflow(llvm::Value * lhs,llvm::Value * rhs)1535 llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs,
1536                                                           llvm::Value* rhs) {
1537   return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())),
1538              ICmpEQ(rhs, GetMinusOne(rhs->getType())));
1539 }
1540 
EmitIntegerDivide(llvm::Value * lhs,llvm::Value * rhs,bool is_signed)1541 llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
1542                                                    llvm::Value* rhs,
1543                                                    bool is_signed) {
1544   // Integer division overflow behavior:
1545   //
1546   // X / 0 == -1
1547   // INT_SMIN /s -1 = INT_SMIN
1548 
1549   if (!is_signed) {
1550     llvm::Value* udiv_is_unsafe = IsZero(rhs);
1551     llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs);
1552     llvm::Value* safe_div = UDiv(lhs, safe_rhs);
1553     return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div);
1554   }
1555 
1556   llvm::Value* has_zero_divisor = IsZero(rhs);
1557   llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
1558   llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
1559   llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs);
1560   llvm::Value* safe_div = SDiv(lhs, safe_rhs);
1561 
1562   return Select(
1563       has_zero_divisor, GetMinusOne(lhs->getType()),
1564       Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div));
1565 }
1566 
EmitIntegerRemainder(llvm::Value * lhs,llvm::Value * rhs,bool is_signed)1567 llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
1568                                                       llvm::Value* rhs,
1569                                                       bool is_signed) {
1570   // Integer remainder overflow behavior:
1571   //
1572   // X % 0 == X
1573   // INT_SMIN %s -1 = 0
1574 
1575   if (!is_signed) {
1576     llvm::Value* urem_is_unsafe = IsZero(rhs);
1577     llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs);
1578     llvm::Value* safe_rem = URem(lhs, safe_rhs);
1579     return Select(urem_is_unsafe, lhs, safe_rem);
1580   }
1581 
1582   llvm::Value* has_zero_divisor = IsZero(rhs);
1583   llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
1584   llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor);
1585   llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs);
1586   llvm::Value* safe_rem = SRem(lhs, safe_rhs);
1587 
1588   return Select(
1589       has_zero_divisor, lhs,
1590       Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem));
1591 }
1592 
EmitIntegerPow(llvm::Value * base,llvm::Value * exponent,bool is_signed)1593 llvm::Value* ElementalIrEmitter::EmitIntegerPow(llvm::Value* base,
1594                                                 llvm::Value* exponent,
1595                                                 bool is_signed) {
1596   // Exponentiation by squaring:
1597   // https://en.wikipedia.org/wiki/Exponentiation_by_squaring;
1598   int bits = 6;  // Everything else would overflow for any exponent > 1, as 2^64
1599                  // is the larget possible exponent for a 64-bit integer, and
1600                  // that's 1 << 6.
1601   llvm::Value* accumulator = llvm::ConstantInt::get(base->getType(), 1);
1602   llvm::Value* one = llvm::ConstantInt::get(exponent->getType(), 1);
1603   llvm::Value* zero = llvm::ConstantInt::get(exponent->getType(), 0);
1604   llvm::Value* original_base = base;
1605   llvm::Value* original_exponent = exponent;
1606 
1607   // Unroll the loop at compile time.
1608   for (int i = 0; i < bits; i++) {
1609     accumulator =
1610         b_->CreateSelect(b_->CreateICmpEQ(b_->CreateAnd(exponent, one), one),
1611                          b_->CreateMul(accumulator, base), accumulator);
1612     base = b_->CreateMul(base, base);
1613     exponent = b_->CreateLShr(exponent, 1);
1614   }
1615   return b_->CreateSelect(
1616       b_->CreateICmpSGE(original_exponent, zero), accumulator,
1617       b_->CreateSelect(b_->CreateICmpEQ(original_base, one), one, zero));
1618 }
1619 
EmitIntegerBinaryOp(const HloInstruction * op,llvm::Value * lhs_value,llvm::Value * rhs_value,bool is_signed)1620 StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
1621     const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
1622     bool is_signed) {
1623   switch (op->opcode()) {
1624     // TODO(jingyue): add the "nsw" attribute for signed types.
1625     case HloOpcode::kAdd:
1626       return Add(lhs_value, rhs_value);
1627     case HloOpcode::kSubtract:
1628       return Sub(lhs_value, rhs_value);
1629     case HloOpcode::kMultiply:
1630       return Mul(lhs_value, rhs_value);
1631     case HloOpcode::kDivide:
1632       return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
1633     case HloOpcode::kRemainder:
1634       return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
1635     case HloOpcode::kCompare: {
1636       switch (op->comparison_direction()) {
1637         case ComparisonDirection::kEq:
1638           return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
1639                                          rhs_value, b_);
1640         case ComparisonDirection::kNe:
1641           return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
1642                                          rhs_value, b_);
1643         case ComparisonDirection::kLt:
1644           return llvm_ir::EmitComparison(
1645               is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
1646               lhs_value, rhs_value, b_);
1647         case ComparisonDirection::kGt:
1648           return llvm_ir::EmitComparison(
1649               is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
1650               lhs_value, rhs_value, b_);
1651         case ComparisonDirection::kLe:
1652           return llvm_ir::EmitComparison(
1653               is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
1654               lhs_value, rhs_value, b_);
1655         case ComparisonDirection::kGe:
1656           return llvm_ir::EmitComparison(
1657               is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
1658               lhs_value, rhs_value, b_);
1659       }
1660     }
1661     case HloOpcode::kMinimum:
1662       return EmitIntegralMin(lhs_value, rhs_value, is_signed);
1663     case HloOpcode::kMaximum:
1664       return EmitIntegralMax(lhs_value, rhs_value, is_signed);
1665     case HloOpcode::kAnd:
1666       return And(lhs_value, rhs_value);
1667     case HloOpcode::kOr:
1668       return Or(lhs_value, rhs_value);
1669     case HloOpcode::kPower:
1670       return EmitIntegerPow(lhs_value, rhs_value, is_signed);
1671     case HloOpcode::kXor:
1672       return Xor(lhs_value, rhs_value);
1673 
1674     // Shifting out bits >= the number of bits in the type being shifted
1675     // produces a poison value in LLVM which is basically "deferred undefined
1676     // behavior" -- doing something observable with such a value precipitates
1677     // UB.  We replace the poison value with a constant to avoid this deferred
1678     // UB.
1679     case HloOpcode::kShiftRightArithmetic:
1680       return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1681                                       AShr(lhs_value, rhs_value),
1682                                       /*saturate_to_sign_bit=*/true);
1683     case HloOpcode::kShiftLeft:
1684       return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1685                                       Shl(lhs_value, rhs_value),
1686                                       /*saturate_to_sign_bit=*/false);
1687     case HloOpcode::kShiftRightLogical:
1688       return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
1689                                       LShr(lhs_value, rhs_value),
1690                                       /*saturate_to_sign_bit=*/false);
1691     default:
1692       return Unimplemented("binary integer op '%s'",
1693                            HloOpcodeString(op->opcode()));
1694   }
1695 }
1696 
EmitIntegralMax(llvm::Value * lhs_value,llvm::Value * rhs_value,bool is_signed)1697 llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
1698                                                  llvm::Value* rhs_value,
1699                                                  bool is_signed) {
1700   return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
1701                                          : llvm::ICmpInst::ICMP_UGE,
1702                                lhs_value, rhs_value),
1703                 lhs_value, rhs_value);
1704 }
1705 
EmitIntegralMin(llvm::Value * lhs_value,llvm::Value * rhs_value,bool is_signed)1706 llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
1707                                                  llvm::Value* rhs_value,
1708                                                  bool is_signed) {
1709   return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
1710                                          : llvm::ICmpInst::ICMP_ULE,
1711                                lhs_value, rhs_value),
1712                 lhs_value, rhs_value);
1713 }
1714 
EmitElementalSelect(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)1715 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
1716     const HloInstruction* hlo,
1717     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1718     const llvm_ir::IrArray::Index& index) {
1719   TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
1720                       operand_to_generator.at(hlo->operand(0))(index));
1721   TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
1722                       operand_to_generator.at(hlo->operand(1))(index));
1723   TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
1724                       operand_to_generator.at(hlo->operand(2))(index));
1725   return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value,
1726                 on_false_value);
1727 }
1728 
EmitElementalClamp(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)1729 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
1730     const HloInstruction* hlo,
1731     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1732     const llvm_ir::IrArray::Index& index) {
1733   TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
1734                       operand_to_generator.at(hlo->operand(0))(index));
1735   TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
1736                       operand_to_generator.at(hlo->operand(1))(index));
1737   TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
1738                       operand_to_generator.at(hlo->operand(2))(index));
1739   PrimitiveType prim_type = hlo->shape().element_type();
1740   if (primitive_util::IsFloatingPointType(prim_type)) {
1741     return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value, ""), "");
1742   } else if (primitive_util::IsIntegralType(prim_type)) {
1743     bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
1744     return EmitIntegralMin(
1745         max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed);
1746   } else {
1747     return Unimplemented("Clamp unimplemented for %s",
1748                          PrimitiveType_Name(prim_type));
1749   }
1750 }
1751 
// Emits IR that produces the element of the concatenate instruction `hlo` at
// `target_index`.
//
// The emitted control flow is a chain of range checks on the concatenation
// dimension, one per operand in operand order. Each check branches to a block
// dedicated to that operand (shared by every use when the same operand appears
// more than once). That block reads the operand at the index with the
// operand's starting offset subtracted and feeds the value into a PHI in the
// exit block. The block reached when every check fails is terminated with
// `unreachable`.
//
// On success the builder is left positioned in the exit block and the PHI
// holding the selected element is returned. An error from an operand generator
// is propagated.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
    const llvm_ir::IrArray::Index& target_index) {
  const int64 concat_dim = hlo->dimensions(0);
  auto source_index = target_index;

  llvm::BasicBlock* init_block = b_->GetInsertBlock();

  // A terminator should be present iff we're emitting code
  // into the middle (as opposed to the end) of a basic block.
  CHECK_EQ(b_->GetInsertPoint() == init_block->end(),
           init_block->getTerminator() == nullptr);

  // Create the block that all operand paths merge into. When inserting at the
  // end of `init_block` a fresh block is created; otherwise `init_block` is
  // split at the insert point and the branch added by the split is removed so
  // that the range-check chain can be emitted in its place.
  llvm::BasicBlock* exit_block;
  if (b_->GetInsertPoint() == init_block->end()) {
    exit_block = llvm_ir::CreateBasicBlock(
        /*insert_before=*/nullptr, IrName(hlo, "merge"), b_);
  } else {
    exit_block =
        init_block->splitBasicBlock(b_->GetInsertPoint(), IrName(hlo, "merge"));
    init_block->getTerminator()->eraseFromParent();
  }

  // The result PHI lives at the start of the exit block; one incoming value is
  // added per distinct operand below.
  llvm_ir::SetToFirstInsertPoint(exit_block, b_);
  llvm::PHINode* output =
      PHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
          hlo->operands().size());
  // Remember the position just after the PHI so the builder can be restored
  // there once the operand blocks have been emitted.
  auto prior_insert_point = b_->GetInsertPoint();

  b_->SetInsertPoint(init_block);

  // Assign a unique id for each *different* operand, and count how often each
  // operand is used. If all operands are different, the usage count will be 1
  // for each operand.
  absl::flat_hash_map<const HloInstruction*, int64> to_unique_operand_id;
  std::vector<int64> operand_usage_count;
  for (const auto* operand : hlo->operands()) {
    if (to_unique_operand_id.contains(operand)) {
      ++operand_usage_count[to_unique_operand_id[operand]];
    } else {
      int64 unique_operand_id = to_unique_operand_id.size();
      to_unique_operand_id[operand] = unique_operand_id;
      operand_usage_count.push_back(1);
    }
  }

  // To avoid that we emit the same operand more than once, we create one basic
  // block for each *different* operand with a PHI node for the different source
  // index inputs.
  std::vector<llvm::BasicBlock*> emit_operand_blocks(
      to_unique_operand_id.size(), nullptr);
  std::vector<llvm::PHINode*> source_index_phis(to_unique_operand_id.size(),
                                                nullptr);
  for (const auto* operand : hlo->operands()) {
    int64 operand_id = to_unique_operand_id[operand];
    // A block already exists for this operand: it was seen earlier in the
    // operand list.
    if (emit_operand_blocks[operand_id] != nullptr) {
      continue;
    }

    emit_operand_blocks[operand_id] = llvm_ir::CreateBasicBlock(
        exit_block, StrCat("concat_index_from_operand_id", operand_id), b_);
    auto saved_insert_point = b_->GetInsertPoint();
    llvm_ir::SetToFirstInsertPoint(emit_operand_blocks[operand_id], b_);
    // This PHI receives the operand's starting offset along the concat
    // dimension, one incoming value per use of the operand (added in the
    // range-check loop below).
    source_index_phis[operand_id] =
        PHI(source_index.GetType(), operand_usage_count[operand_id]);
    // Translate the output index into the operand's index space by removing
    // that starting offset from the concat dimension.
    std::vector<llvm::Value*> operand_multi_index = source_index.multidim();
    operand_multi_index[concat_dim] =
        NSWSub(operand_multi_index[concat_dim], source_index_phis[operand_id]);

    // Create the terminator of the block before calling operand generators,
    // because they require non-degenerate basic blocks.
    b_->SetInsertPoint(llvm::BranchInst::Create(
        exit_block, /*InsertAtEnd=*/emit_operand_blocks[operand_id]));
    llvm_ir::IrArray::Index operand_index(operand_multi_index, operand->shape(),
                                          source_index.GetType());
    TF_ASSIGN_OR_RETURN(llvm::Value * value,
                        operand_to_generator.at(operand)(operand_index));
    // Use the builder's current block as the predecessor rather than the block
    // created above, since the generator ran in between and the insert block
    // is re-read here.
    output->addIncoming(value, b_->GetInsertBlock());
    b_->SetInsertPoint(init_block, saved_insert_point);
  }

  // Emit the chain of range checks. `concat_dim_size` accumulates the extent
  // of the operands visited so far, i.e. the starting offset of the current
  // operand before the update and its (exclusive) end offset after it.
  int64 concat_dim_size = 0;
  for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
       ++operand_idx) {
    const HloInstruction* operand = hlo->operand(operand_idx);
    auto false_block = llvm_ir::CreateBasicBlock(
        exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_);
    int64 operand_id = to_unique_operand_id[operand];
    // Record this use's starting offset as the value the operand block's
    // source-index PHI takes when entered from the current block.
    source_index_phis[operand_id]->addIncoming(
        source_index.GetConstantWithIndexType(concat_dim_size),
        b_->GetInsertBlock());
    concat_dim_size += operand->shape().dimensions(concat_dim);
    // Branch to the operand's block when the index lies below the running end
    // offset (unsigned comparison); otherwise continue with the next check.
    CondBr(ICmpULT(source_index[concat_dim],
                   source_index.GetConstantWithIndexType(concat_dim_size)),
           emit_operand_blocks[operand_id], false_block);

    // Continue emitting in the block taken when the index does not belong to
    // this operand. The offset subtraction itself is performed in the
    // operand's block through its source-index PHI.
    b_->SetInsertPoint(false_block);
  }

  // The last `false_block` is reached only if no range check matched.
  Unreachable();
  b_->SetInsertPoint(exit_block, prior_insert_point);
  return output;
}
1858 
EmitElementalDynamicSlice(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)1859 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
1860     const HloInstruction* hlo,
1861     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1862     const llvm_ir::IrArray::Index& index) {
1863   // Emit IR to read dynamic start indices from hlo->operand(1).
1864   const HloInstruction* input_hlo = hlo->operand(0);
1865   const int64 rank = input_hlo->shape().rank();
1866   // Use the same index type for all tensor accesses in the same kernel.
1867   llvm::Type* index_type = index.GetType();
1868   std::vector<llvm::Value*> slice_start_multi_index(rank);
1869   for (int64 i = 0; i < rank; ++i) {
1870     auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
1871       return llvm::ConstantInt::get(index_type, c);
1872     };
1873     llvm_ir::IrArray::Index zero_index(index_type);
1874     TF_ASSIGN_OR_RETURN(
1875         llvm::Value * start_index_value,
1876         operand_to_generator.at(hlo->operand(1 + i))(zero_index));
1877 
1878     // Clamp the start index so that the sliced portion fits in the operand:
1879     // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
1880     start_index_value = SExtOrTrunc(start_index_value, index_type);
1881     int64 largest_valid_start_index =
1882         input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
1883     CHECK_GE(largest_valid_start_index, 0);
1884 
1885     bool is_signed = ShapeUtil::ElementIsSigned(hlo->operand(1)->shape());
1886     start_index_value = EmitIntegralMin(
1887         index_typed_const(largest_valid_start_index),
1888         EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
1889         is_signed);
1890 
1891     start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
1892     slice_start_multi_index[i] = start_index_value;
1893   }
1894 
1895   std::vector<llvm::Value*> input_multi_index(rank);
1896   for (int64 i = 0; i < rank; ++i) {
1897     // Emit IR which computes:
1898     //   input_index = start_index + offset_index
1899     input_multi_index[i] = Add(slice_start_multi_index[i], index[i]);
1900   }
1901   llvm_ir::IrArray::Index input_index(input_multi_index, input_hlo->shape(),
1902                                       index_type);
1903   return operand_to_generator.at(input_hlo)(input_index);
1904 }
1905 
EmitElementalGather(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)1906 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
1907     const HloInstruction* hlo,
1908     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
1909     const llvm_ir::IrArray::Index& index) {
1910   const Shape& operand_shape = hlo->operand(0)->shape();
1911   const Shape& indices_shape = hlo->operand(1)->shape();
1912   const Shape& output_shape = hlo->shape();
1913 
1914   const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers();
1915 
1916   const llvm_ir::ElementGenerator& operand_generator =
1917       operand_to_generator.at(hlo->operand(0));
1918   const llvm_ir::ElementGenerator& indices_generator =
1919       operand_to_generator.at(hlo->operand(1));
1920 
1921   llvm::Type* index_type = index.GetType();
1922   // This is the index into `operand` that holds the element we want to
1923   // generate.
1924   std::vector<llvm::Value*> operand_multi_index;
1925 
1926   // First copy in the window indices to operand_index. Also collect a mapping
1927   // from operand dimension to output window dimension. Elided window dimensions
1928   // map to -1.
1929   std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
1930   for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
1931        i < e; i++) {
1932     if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
1933       operand_multi_index.push_back(index.GetConstantWithIndexType(0));
1934     } else {
1935       int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++);
1936       operand_to_output_dim[i] = output_window_dim;
1937       operand_multi_index.push_back(index[output_window_dim]);
1938     }
1939   }
1940 
1941   // This is the index of the index vector in the start_indices tensor.
1942   std::vector<llvm::Value*> gather_index_index_components;
1943   {
1944     for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
1945       if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
1946         gather_index_index_components.push_back(index[i]);
1947       }
1948     }
1949 
1950     if (gather_index_index_components.size() !=
1951         indices_shape.dimensions_size()) {
1952       gather_index_index_components.insert(
1953           gather_index_index_components.begin() +
1954               dim_numbers.index_vector_dim(),
1955           nullptr);
1956     }
1957   }
1958 
1959   auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
1960     auto index_component_type = index_component->getType();
1961     auto extended_type = index_component_type->getScalarSizeInBits() >=
1962                                  index_type->getScalarSizeInBits()
1963                              ? index_component_type
1964                              : index_type;
1965     // Possibly extend the value at the beginning to ensure clamping logic stays
1966     // in bounds.
1967     auto maybe_extended_index =
1968         index_component_type != extended_type
1969             ? b_->CreateSExt(index_component, extended_type)
1970             : index_component;
1971     int64 operand_dim = dim_numbers.start_index_map(dim);
1972     int64 output_dim = operand_to_output_dim[operand_dim];
1973     // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
1974     // This means we set the iteration index to 0, so for the purpose of the
1975     // following calculations we can consider the output dimension size to be 1.
1976     int64 output_dim_size =
1977         output_dim == -1 ? 1 : output_shape.dimensions(output_dim);
1978     int64 largest_valid_start_index =
1979         operand_shape.dimensions(operand_dim) - output_dim_size;
1980     CHECK_GE(largest_valid_start_index, 0);
1981 
1982     // Clamp the gather index so that the gather region fits in the operand.
1983     // clamped_index =
1984     //     clamp(gather_dim_component_extended, 0, largest_valid_start_index);
1985     bool is_signed = ShapeUtil::ElementIsSigned(indices_shape);
1986     auto clamped_index = EmitIntegralMin(
1987         llvm::ConstantInt::get(extended_type, largest_valid_start_index),
1988         EmitIntegralMax(llvm::ConstantInt::get(extended_type, 0),
1989                         maybe_extended_index, is_signed),
1990         is_signed);
1991     // Truncate at the end to the optimized index size
1992     auto maybe_truncated_clamped_index = extended_type != index_type
1993                                              ? Trunc(clamped_index, index_type)
1994                                              : clamped_index;
1995 
1996     operand_multi_index[operand_dim] =
1997         Add(operand_multi_index[operand_dim], maybe_truncated_clamped_index);
1998   };
1999 
2000   if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
2001     IrArray::Index gather_index_index(gather_index_index_components,
2002                                       indices_shape, index_type);
2003     TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
2004                         indices_generator(gather_index_index));
2005     add_to_operand_index(gather_dim_component, 0);
2006   } else {
2007     int64 index_vector_size =
2008         indices_shape.dimensions(dim_numbers.index_vector_dim());
2009     for (int64 i = 0; i < index_vector_size; i++) {
2010       gather_index_index_components[dim_numbers.index_vector_dim()] =
2011           index.GetConstantWithIndexType(i);
2012       IrArray::Index gather_index_index(gather_index_index_components,
2013                                         indices_shape, index_type);
2014       TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
2015                           indices_generator(gather_index_index));
2016       add_to_operand_index(gather_dim_component, i);
2017     }
2018   }
2019   IrArray::Index operand_index(operand_multi_index, operand_shape, index_type);
2020   return operand_generator(operand_index);
2021 }
2022 
EmitElementalDynamicUpdateSlice(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)2023 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
2024     const HloInstruction* hlo,
2025     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2026     const llvm_ir::IrArray::Index& index) {
2027   const HloInstruction* input_hlo = hlo->operand(0);
2028   const HloInstruction* update_hlo = hlo->operand(1);
2029   const HloInstruction* start_hlo = hlo->operand(2);
2030   // Calculate slice start/end indices.
2031   const int64 rank = input_hlo->shape().rank();
2032   std::vector<llvm::Value*> slice_start_multi_index(rank);
2033   std::vector<llvm::Value*> slice_limit_multi_index(rank);
2034   // Slice intersection gathers (ANDs) conditions on all ranks for which
2035   // 'input' is set to 'update'
2036   llvm::Value* slice_intersection = b_->getTrue();
2037 
2038   for (int64 i = 0; i < rank; ++i) {
2039     llvm::Type* index_type = index[0]->getType();
2040     auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
2041       return llvm::ConstantInt::get(index_type, c);
2042     };
2043 
2044     llvm_ir::IrArray::Index zero_index(index_type);
2045     TF_ASSIGN_OR_RETURN(
2046         llvm::Value * start_index_value,
2047         operand_to_generator.at(hlo->operand(2 + i))(zero_index));
2048 
2049     // Clamp the start index so that the update region fits in the operand.
2050     // start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
2051     start_index_value = SExtOrTrunc(start_index_value, index_type);
2052     llvm::Value* update_dim_size =
2053         index_typed_const(update_hlo->shape().dimensions(i));
2054     int64 largest_valid_start_index =
2055         input_hlo->shape().dimensions(i) - update_hlo->shape().dimensions(i);
2056     CHECK_GE(largest_valid_start_index, 0);
2057 
2058     bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape());
2059     start_index_value = EmitIntegralMin(
2060         index_typed_const(largest_valid_start_index),
2061         EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
2062         is_signed);
2063 
2064     start_index_value->setName(IrName(hlo, StrCat("start_idx", i)));
2065     slice_start_multi_index[i] = start_index_value;
2066     slice_limit_multi_index[i] =
2067         Add(slice_start_multi_index[i], update_dim_size);
2068 
2069     slice_intersection =
2070         And(slice_intersection, ICmpSGE(index[i], slice_start_multi_index[i]),
2071             "slice_intersection");
2072     slice_intersection =
2073         And(slice_intersection, ICmpSLT(index[i], slice_limit_multi_index[i]),
2074             "slice_intersection");
2075   }
2076 
2077   // Emit:
2078   // if (slice_intersection) -> return data from 'update'.
2079   // else                    -> return data from 'input'.
2080   llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
2081       llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
2082       "ret_value_addr", b_);
2083   llvm_ir::LlvmIfData if_data =
2084       llvm_ir::EmitIfThenElse(slice_intersection, "slice_intersection", b_);
2085 
2086   // Handle true BB (return data from 'update')
2087   SetToFirstInsertPoint(if_data.true_block, b_);
2088   // Compute update index for intersection case.
2089   std::vector<llvm::Value*> update_multi_index(rank);
2090   for (int64 i = 0; i < rank; ++i) {
2091     update_multi_index[i] = Sub(index[i], slice_start_multi_index[i]);
2092   }
2093   llvm_ir::IrArray::Index update_index(update_multi_index, update_hlo->shape(),
2094                                        index.GetType());
2095   TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
2096                       operand_to_generator.at(update_hlo)(update_index));
2097   Store(true_value, ret_value_addr);
2098 
2099   // Handle false BB (return data from 'input')
2100   SetToFirstInsertPoint(if_data.false_block, b_);
2101   TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
2102                       operand_to_generator.at(input_hlo)(index));
2103   Store(false_value, ret_value_addr);
2104 
2105   SetToFirstInsertPoint(if_data.after_block, b_);
2106   return Load(ret_value_addr);
2107 }
2108 
EmitElementalPad(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & padded_index)2109 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
2110     const HloInstruction* hlo,
2111     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2112     const llvm_ir::IrArray::Index& padded_index) {
2113   std::vector<llvm::Value*> multi_index = padded_index.multidim();
2114   llvm::Value* in_bounds = b_->getTrue();
2115   for (size_t i = 0; i < multi_index.size(); ++i) {
2116     auto index_typed_const = [=](int64 n) {
2117       return padded_index.GetConstantWithIndexType(n);
2118     };
2119     const auto& pad_dim = hlo->padding_config().dimensions(i);
2120     multi_index[i] =
2121         Sub(multi_index[i], index_typed_const(pad_dim.edge_padding_low()));
2122     in_bounds = And(in_bounds, ICmpSGE(multi_index[i], index_typed_const(0)),
2123                     "in_bounds");
2124     in_bounds =
2125         And(in_bounds,
2126             ICmpEQ(index_typed_const(0),
2127                    URem(multi_index[i],
2128                         index_typed_const(pad_dim.interior_padding() + 1))),
2129             "in_bounds");
2130     multi_index[i] =
2131         SDiv(multi_index[i], index_typed_const(pad_dim.interior_padding() + 1));
2132     in_bounds =
2133         And(in_bounds,
2134             ICmpSLT(multi_index[i],
2135                     index_typed_const(hlo->operand(0)->shape().dimensions(i))),
2136             "in_bounds");
2137   }
2138 
2139   // if (in_bounds) {
2140   //   ret_value = operand0[index];  // source
2141   // } else {
2142   //   ret_value = *operand1;        // padding
2143   // }
2144   llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
2145       llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
2146       "pad_result_addr", b_);
2147   llvm_ir::LlvmIfData if_data =
2148       llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
2149   SetToFirstInsertPoint(if_data.true_block, b_);
2150   llvm_ir::IrArray::Index index(multi_index, hlo->operand(0)->shape(),
2151                                 padded_index.GetType());
2152   TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2153                       operand_to_generator.at(hlo->operand(0))(index));
2154   Store(operand_value, ret_value_addr);
2155 
2156   SetToFirstInsertPoint(if_data.false_block, b_);
2157   TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
2158                       operand_to_generator.at(hlo->operand(1))(
2159                           IrArray::Index(index.GetType())));
2160   Store(padding_value, ret_value_addr);
2161 
2162   SetToFirstInsertPoint(if_data.after_block, b_);
2163   // Don't create phi(operand_value, padding_value) here, because invoking
2164   // operand_to_generator may create new basic blocks, making the parent
2165   // of operand_value or padding_value no longer a predecessor of
2166   // if_data.after_block.
2167   return Load(ret_value_addr);
2168 }
2169 
EmitElementalDot(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & dot_result_index)2170 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
2171     const HloInstruction* hlo,
2172     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2173     const llvm_ir::IrArray::Index& dot_result_index) {
2174   auto lhs_generator = operand_to_generator.at(hlo->operand(0));
2175   auto rhs_generator = operand_to_generator.at(hlo->operand(1));
2176 
2177   const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers();
2178   int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0);
2179   int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0);
2180 
2181   int64 contracted_dim_size =
2182       hlo->operand(0)->shape().dimensions(lhs_contracting_dim);
2183   int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
2184   int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
2185 
2186   llvm::Type* index_type = dot_result_index.GetType();
2187   auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
2188     return llvm::ConstantInt::get(index_type, c);
2189   };
2190 
2191   std::unique_ptr<llvm_ir::ForLoop> inner_loop = llvm_ir::ForLoop::EmitForLoop(
2192       IrName(hlo, "inner"), index_typed_const(0),
2193       index_typed_const(contracted_dim_size), index_typed_const(1), b_);
2194 
2195   SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_);
2196   PrimitiveType primitive_type = hlo->shape().element_type();
2197   llvm::Type* primitive_type_llvm =
2198       llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
2199   llvm::Value* accumulator_alloca =
2200       llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_);
2201   Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca);
2202 
2203   SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_);
2204 
2205   // This is the inner reduction loop for a dot operation that produces
2206   // one element in the output.  If the operands to the dot operation have
2207   // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
2208   // Given an output index [a,b,c,d,e] in the result, we compute:
2209   //   sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
2210 
2211   std::vector<llvm::Value*> lhs_multi_index, rhs_multi_index;
2212   for (int64 i = 0; i < lhs_dims - 1; i++) {
2213     lhs_multi_index.push_back(dot_result_index[i]);
2214   }
2215   lhs_multi_index.insert(lhs_multi_index.begin() + lhs_contracting_dim,
2216                          inner_loop->GetIndVarValue());
2217   IrArray::Index lhs_index(lhs_multi_index, hlo->operand(0)->shape(),
2218                            index_type);
2219 
2220   int64 num_batch_dims = dim_numbers.rhs_batch_dimensions_size();
2221   for (int64 i = 0; i < num_batch_dims; i++) {
2222     rhs_multi_index.push_back(
2223         dot_result_index[dim_numbers.rhs_batch_dimensions(i)]);
2224   }
2225   for (int64 i = 0; i < rhs_dims - 1 - num_batch_dims; i++) {
2226     rhs_multi_index.push_back(dot_result_index[lhs_dims - 1 + i]);
2227   }
2228   rhs_multi_index.insert(rhs_multi_index.begin() + rhs_contracting_dim,
2229                          inner_loop->GetIndVarValue());
2230   IrArray::Index rhs_index(rhs_multi_index, hlo->operand(1)->shape(),
2231                            index_type);
2232 
2233   llvm::Value* current_accumulator = Load(accumulator_alloca);
2234   TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
2235   TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
2236   llvm::Value* next_accumulator =
2237       EmitMulAdd(lhs_value, rhs_value, current_accumulator, primitive_type);
2238   Store(next_accumulator, accumulator_alloca);
2239 
2240   SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
2241   return Load(accumulator_alloca);
2242 }
2243 
MakeElementGenerator(const HloInstruction * hlo,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator)2244 llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
2245     const HloInstruction* hlo,
2246     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) {
2247   switch (hlo->opcode()) {
2248     case HloOpcode::kAbs:
2249     case HloOpcode::kRoundNearestAfz:
2250     case HloOpcode::kCeil:
2251     case HloOpcode::kClz:
2252     case HloOpcode::kConvert:
2253     case HloOpcode::kBitcastConvert:
2254     case HloOpcode::kCos:
2255     case HloOpcode::kExp:
2256     case HloOpcode::kExpm1:
2257     case HloOpcode::kFloor:
2258     case HloOpcode::kImag:
2259     case HloOpcode::kIsFinite:
2260     case HloOpcode::kLog:
2261     case HloOpcode::kLog1p:
2262     case HloOpcode::kNegate:
2263     case HloOpcode::kNot:
2264     case HloOpcode::kPopulationCount:
2265     case HloOpcode::kReal:
2266     case HloOpcode::kRsqrt:
2267     case HloOpcode::kSign:
2268     case HloOpcode::kSin:
2269     case HloOpcode::kSqrt:
2270     case HloOpcode::kCbrt:
2271     case HloOpcode::kTanh:
2272       return [this, hlo, &operand_to_generator](
2273                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2274         TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2275                             operand_to_generator.at(hlo->operand(0))(index));
2276         return EmitUnaryOp(hlo, operand_value);
2277       };
2278     case HloOpcode::kAdd:
2279     case HloOpcode::kAnd:
2280     case HloOpcode::kAtan2:
2281     case HloOpcode::kCompare:
2282     case HloOpcode::kComplex:
2283     case HloOpcode::kDivide:
2284     case HloOpcode::kMaximum:
2285     case HloOpcode::kMinimum:
2286     case HloOpcode::kMultiply:
2287     case HloOpcode::kOr:
2288     case HloOpcode::kXor:
2289     case HloOpcode::kPower:
2290     case HloOpcode::kRemainder:
2291     case HloOpcode::kShiftLeft:
2292     case HloOpcode::kShiftRightArithmetic:
2293     case HloOpcode::kShiftRightLogical:
2294     case HloOpcode::kSubtract:
2295       return [this, hlo, &operand_to_generator](
2296                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2297         const HloInstruction* lhs = hlo->operand(0);
2298         const HloInstruction* rhs = hlo->operand(1);
2299         TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
2300                             operand_to_generator.at(lhs)(index));
2301         TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
2302                             operand_to_generator.at(rhs)(index));
2303         return EmitBinaryOp(hlo, lhs_value, rhs_value);
2304       };
2305     case HloOpcode::kSelect:
2306       return [this, hlo, &operand_to_generator](
2307                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2308         return EmitElementalSelect(hlo, operand_to_generator, index);
2309       };
2310     case HloOpcode::kClamp:
2311       return [this, hlo, &operand_to_generator](
2312                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2313         return EmitElementalClamp(hlo, operand_to_generator, index);
2314       };
2315     case HloOpcode::kReducePrecision:
2316       return [this, hlo, &operand_to_generator](
2317                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2318         TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2319                             operand_to_generator.at(hlo->operand(0))(index));
2320         return EmitReducePrecision(hlo, operand_value);
2321       };
2322     case HloOpcode::kConcatenate:
2323       return [this, hlo, &operand_to_generator](
2324                  const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
2325         return EmitElementalConcatenate(hlo, operand_to_generator,
2326                                         target_index);
2327       };
2328     case HloOpcode::kReverse:
2329       return [this, hlo, &operand_to_generator](
2330                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2331         const HloInstruction* operand = hlo->operand(0);
2332         std::vector<llvm::Value*> source_multi_index = target_index.multidim();
2333         for (int64 dim : hlo->dimensions()) {
2334           source_multi_index[dim] = Sub(target_index.GetConstantWithIndexType(
2335                                             hlo->shape().dimensions(dim) - 1),
2336                                         target_index[dim]);
2337         }
2338         llvm_ir::IrArray::Index source_index(
2339             source_multi_index, operand->shape(), target_index.GetType());
2340         return operand_to_generator.at(operand)(source_index);
2341       };
2342     case HloOpcode::kBroadcast:
2343       return [this, hlo, &operand_to_generator](
2344                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2345         const HloInstruction* operand = hlo->operand(0);
2346         // The `dimensions` member of the broadcast instruction maps from
2347         // input dimensions to output dimensions.
2348         return operand_to_generator.at(operand)(
2349             target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
2350                                                 hlo->dimensions(), b_));
2351       };
2352     case HloOpcode::kIota:
2353       return [this, hlo](
2354                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2355         auto* iota = Cast<HloIotaInstruction>(hlo);
2356         PrimitiveType element_type = iota->shape().element_type();
2357         IrArray::Index elem_index =
2358             iota->shape().rank() > 1
2359                 ? target_index.SourceIndexOfBroadcast(
2360                       iota->shape(),
2361                       ShapeUtil::MakeShapeWithDescendingLayout(
2362                           element_type,
2363                           {iota->shape().dimensions(iota->iota_dimension())}),
2364                       {iota->iota_dimension()}, b_)
2365                 : target_index;
2366         llvm::Value* elem_index_linear = elem_index.linear();
2367         if (elem_index_linear == nullptr) {
2368           std::vector<int64> iota_bound = {
2369               iota->shape().dimensions(iota->iota_dimension())};
2370           elem_index_linear = elem_index.Linearize(iota_bound, b_);
2371         }
2372         Shape component_shape =
2373             ShapeUtil::ElementIsComplex(iota->shape())
2374                 ? ShapeUtil::ComplexComponentShape(iota->shape())
2375                 : iota->shape();
2376         PrimitiveType component_element_type = component_shape.element_type();
2377         llvm::Value* iota_result;
2378         if (primitive_util::IsIntegralType(component_element_type)) {
2379           iota_result = b_->CreateIntCast(
2380               elem_index_linear,
2381               llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
2382               /*isSigned=*/false);
2383         } else {
2384           TF_RET_CHECK(
2385               primitive_util::IsFloatingPointType(component_element_type))
2386               << component_element_type;
2387           llvm::Type* float_ir_type;
2388           if (component_element_type == BF16) {
2389             float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_);
2390           } else {
2391             float_ir_type =
2392                 llvm_ir::PrimitiveTypeToIrType(component_element_type, module_);
2393           }
2394           llvm::Value* float_val =
2395               b_->CreateUIToFP(elem_index_linear, float_ir_type);
2396           if (component_element_type == BF16) {
2397             TF_ASSIGN_OR_RETURN(iota_result, EmitF32ToBF16(float_val, b_));
2398           } else {
2399             iota_result = float_val;
2400           }
2401         }
2402         if (ShapeUtil::ElementIsComplex(iota->shape())) {
2403           return EmitComposeComplex(iota, iota_result, nullptr);
2404         } else {
2405           return iota_result;
2406         }
2407       };
2408     case HloOpcode::kSlice:
2409       return [this, hlo, &operand_to_generator](
2410                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2411         IrArray::Index sliced_index = index.SourceIndexOfSlice(
2412             /*operand_shape=*/hlo->operand(0)->shape(),
2413             /*starts=*/hlo->slice_starts(),
2414             /*strides=*/hlo->slice_strides(), /*builder=*/b_);
2415         return operand_to_generator.at(hlo->operand(0))(sliced_index);
2416       };
2417     case HloOpcode::kDynamicSlice:
2418       return [this, hlo, &operand_to_generator](
2419                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2420         return EmitElementalDynamicSlice(hlo, operand_to_generator, index);
2421       };
2422 
2423     case HloOpcode::kGather:
2424       return [this, hlo, &operand_to_generator](
2425                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2426         return EmitElementalGather(hlo, operand_to_generator, index);
2427       };
2428     case HloOpcode::kDynamicUpdateSlice:
2429       return [this, hlo, &operand_to_generator](
2430                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2431         return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator,
2432                                                index);
2433       };
2434     case HloOpcode::kBitcast:
2435       CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
2436                ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
2437       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2438         const HloInstruction* operand = hlo->operand(0);
2439         return operand_to_generator.at(operand)(
2440             index.SourceIndexOfBitcast(hlo->shape(), operand->shape(), b_));
2441       };
2442     case HloOpcode::kReshape:
2443       CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
2444                ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
2445       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2446         const HloInstruction* operand = hlo->operand(0);
2447         return operand_to_generator.at(operand)(
2448             index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_));
2449       };
2450     case HloOpcode::kCopy:
2451       return [hlo, &operand_to_generator](
2452                  const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
2453         IrArray::Index source_index(target_index.multidim(),
2454                                     hlo->operand(0)->shape(),
2455                                     target_index.GetType());
2456         TF_ASSIGN_OR_RETURN(
2457             llvm::Value * operand_value,
2458             operand_to_generator.at(hlo->operand(0))(source_index));
2459         return operand_value;
2460       };
2461     case HloOpcode::kTranspose:
2462       return [this, hlo,
2463               &operand_to_generator](const IrArray::Index& target_index) {
2464         return operand_to_generator.at(hlo->operand(0))(
2465             target_index.SourceIndexOfTranspose(
2466                 hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions()));
2467       };
2468     case HloOpcode::kPad:
2469       return [this, hlo, &operand_to_generator](
2470                  const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
2471         return EmitElementalPad(hlo, operand_to_generator, padded_index);
2472       };
2473 
2474     case HloOpcode::kDot:
2475       return [this, hlo,
2476               &operand_to_generator](const IrArray::Index& dot_result_index)
2477                  -> StatusOr<llvm::Value*> {
2478         return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
2479       };
2480     case HloOpcode::kMap:
2481       return [this, hlo, &operand_to_generator](
2482                  const IrArray::Index& index) -> StatusOr<llvm::Value*> {
2483         std::vector<llvm::Value*> operands;
2484         for (int i = 0; i < hlo->operand_count(); i++) {
2485           TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
2486                               operand_to_generator.at(hlo->operand(i))(index));
2487           operands.push_back(operand_value);
2488         }
2489         return EmitElementalMap(Cast<HloMapInstruction>(hlo), operands);
2490       };
2491     case HloOpcode::kReduceWindow:
2492       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2493         auto reduce_window_instr = Cast<HloReduceWindowInstruction>(hlo);
2494         std::vector<llvm_ir::ElementGenerator> input_generators;
2495         for (const HloInstruction* instr :
2496              reduce_window_instr->input_arrays()) {
2497           input_generators.push_back(operand_to_generator.at(instr));
2498         }
2499 
2500         std::vector<llvm_ir::ElementGenerator> initial_value_generators;
2501         for (const HloInstruction* instr : reduce_window_instr->init_values()) {
2502           initial_value_generators.push_back(operand_to_generator.at(instr));
2503         }
2504         return EmitElementalReduceWindow(
2505             Cast<HloReduceWindowInstruction>(hlo), std::move(input_generators),
2506             std::move(initial_value_generators), index);
2507       };
2508     case HloOpcode::kReduce:
2509       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2510         auto reduce_instr = Cast<HloReduceInstruction>(hlo);
2511         std::vector<llvm_ir::ElementGenerator> input_generators;
2512         for (const HloInstruction* instr : reduce_instr->inputs()) {
2513           input_generators.push_back(operand_to_generator.at(instr));
2514         }
2515 
2516         std::vector<llvm_ir::ElementGenerator> initial_value_generators;
2517         for (const HloInstruction* instr : reduce_instr->init_values()) {
2518           initial_value_generators.push_back(operand_to_generator.at(instr));
2519         }
2520         return EmitElementalReduce(reduce_instr, std::move(input_generators),
2521                                    std::move(initial_value_generators), index);
2522       };
2523     case HloOpcode::kConvolution:
2524       return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
2525         return EmitConvolution(hlo, operand_to_generator, index);
2526       };
2527     default:
2528       return [hlo](const IrArray::Index& index) {
2529         return Unimplemented("Unhandled opcode for elemental IR emission: %s",
2530                              HloOpcodeString(hlo->opcode()));
2531       };
2532   }
2533 }
2534 
EmitExtractReal(llvm::Value * value)2535 llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) {
2536   return ExtractValue(value, {0});
2537 }
2538 
EmitExtractImag(llvm::Value * value)2539 llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) {
2540   return ExtractValue(value, {1});
2541 }
2542 
EmitComposeComplex(const HloInstruction * op,llvm::Value * real,llvm::Value * imag)2543 llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
2544                                                     llvm::Value* real,
2545                                                     llvm::Value* imag) {
2546   auto cplx_type =
2547       llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
2548   auto complex =
2549       InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0});
2550   if (imag != nullptr) {
2551     complex = InsertValue(complex, imag, {1});
2552   }
2553   return complex;
2554 }
2555 
EmitMulAdd(llvm::Value * lhs,llvm::Value * rhs,llvm::Value * accumulator,xla::PrimitiveType primitive_type)2556 llvm::Value* ElementalIrEmitter::EmitMulAdd(llvm::Value* lhs, llvm::Value* rhs,
2557                                             llvm::Value* accumulator,
2558                                             xla::PrimitiveType primitive_type) {
2559   if (primitive_util::IsComplexType(primitive_type)) {
2560     llvm::Value* product_real =
2561         FSub(FMul(EmitExtractReal(lhs), EmitExtractReal(rhs)),
2562              FMul(EmitExtractImag(lhs), EmitExtractImag(rhs)));
2563     llvm::Value* product_imag =
2564         FAdd(FMul(EmitExtractReal(lhs), EmitExtractImag(rhs)),
2565              FMul(EmitExtractImag(lhs), EmitExtractReal(rhs)));
2566     llvm::Value* next_accumulator = InsertValue(
2567         accumulator, FAdd(EmitExtractReal(accumulator), product_real), {0});
2568     return InsertValue(next_accumulator,
2569                        FAdd(EmitExtractImag(accumulator), product_imag), {1});
2570   } else if (primitive_util::IsFloatingPointType(primitive_type)) {
2571     return FAdd(accumulator, FPCast(FMul(lhs, rhs), accumulator->getType()));
2572   } else if (primitive_type == PRED) {
2573     return Or(accumulator, And(lhs, rhs));
2574   }
2575   return Add(accumulator, Mul(lhs, rhs));
2576 }
2577 
EmitElementalMap(const HloMapInstruction * map_instr,absl::Span<llvm::Value * const> elemental_operands)2578 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalMap(
2579     const HloMapInstruction* map_instr,
2580     absl::Span<llvm::Value* const> elemental_operands) {
2581   TF_ASSIGN_OR_RETURN(
2582       std::vector<llvm::Value*> values,
2583       EmitThreadLocalCall(*map_instr->to_apply(), elemental_operands,
2584                           llvm_ir::IrName(map_instr)));
2585   CHECK_EQ(values.size(), 1);
2586   return values[0];
2587 }
2588 
EmitElementalReduceWindow(const HloReduceWindowInstruction * reduce_window,std::vector<llvm_ir::ElementGenerator> input_generators,std::vector<llvm_ir::ElementGenerator> initial_value_generators,const llvm_ir::IrArray::Index & index)2589 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
2590     const HloReduceWindowInstruction* reduce_window,
2591     std::vector<llvm_ir::ElementGenerator> input_generators,
2592     std::vector<llvm_ir::ElementGenerator> initial_value_generators,
2593     const llvm_ir::IrArray::Index& index) {
2594   // Pseudocode:
2595   // for each index I in output
2596   //   value = init_value
2597   //   for each index W in window
2598   //     for each dimension i from 0 to rank - 1
2599   //       (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i]
2600   //     if I in bounds of input
2601   //       value = function(value, input[I])
2602   //     output[O] = value
2603   int64 input_count = reduce_window->input_count();
2604   std::vector<PrimitiveType> operand_element_types;
2605   std::vector<llvm::Type*> accum_types;
2606   std::vector<llvm::Value*> accum_ptrs;
2607   for (int64 operand_index = 0; operand_index < input_count; ++operand_index) {
2608     auto operand = reduce_window->input_arrays()[operand_index];
2609     PrimitiveType operand_element_type = operand->shape().element_type();
2610     operand_element_types.push_back(operand_element_type);
2611     llvm::Type* llvm_type =
2612         llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_);
2613     accum_types.push_back(llvm_type);
2614     llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
2615         llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
2616         "reduce_window_accum_ptr", b_);
2617     accum_ptrs.push_back(accum_ptr);
2618     {
2619       auto initial_value_generator = initial_value_generators[operand_index];
2620       TF_ASSIGN_OR_RETURN(
2621           llvm::Value* const init_value,
2622           initial_value_generator(llvm_ir::IrArray::Index(index.GetType())));
2623       Store(init_value, accum_ptr);
2624     }
2625   }
2626 
2627   llvm::Type* index_type = index.GetType();
2628   auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
2629     return index.GetConstantWithIndexType(c);
2630   };
2631 
2632   const Window& window = reduce_window->window();
2633   llvm_ir::ForLoopNest loops(IrName(reduce_window), b_, index_type);
2634   std::vector<int64> window_size;
2635   for (const auto& dim : window.dimensions()) {
2636     window_size.push_back(dim.size());
2637   }
2638   const IrArray::Index window_index = loops.AddLoopsForShape(
2639       ShapeUtil::MakeShape(operand_element_types[0], window_size), "window");
2640   CHECK_EQ(window_index.size(), index.size());
2641 
2642   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
2643 
2644   std::vector<llvm::Value*> input_multi_index(index.size());
2645   llvm::Value* in_bounds = b_->getInt1(true);
2646   for (size_t i = 0; i < index.size(); ++i) {
2647     llvm::Value* stridden_index =
2648         NSWMul(index[i], index_typed_const(window.dimensions(i).stride()));
2649     input_multi_index[i] = NSWSub(
2650         NSWAdd(
2651             stridden_index,
2652             NSWMul(window_index[i],
2653                    index_typed_const(window.dimensions(i).window_dilation()))),
2654         index_typed_const(window.dimensions(i).padding_low()));
2655 
2656     // We need to verify that we are not in the dilated base area.
2657     llvm::Value* dilation_condition =
2658         ICmpEQ(SRem(input_multi_index[i],
2659                     index_typed_const(window.dimensions(i).base_dilation())),
2660                index_typed_const(0));
2661     in_bounds = And(in_bounds, dilation_condition);
2662 
2663     // Apply base dilation to the index.
2664     input_multi_index[i] =
2665         SDiv(input_multi_index[i],
2666              index_typed_const(window.dimensions(i).base_dilation()));
2667 
2668     // We must check whether 0 <= input_multi_index[i] < bound, as
2669     // otherwise we are in the pad and so can skip the computation. This
2670     // comparison is equivalent to the unsigned comparison
2671     // input_multi_index[i] < bound, as a negative value wraps to a large
2672     // positive value.
2673     in_bounds = And(
2674         in_bounds,
2675         ICmpULT(input_multi_index[i],
2676                 index_typed_const(
2677                     reduce_window->input_arrays()[0]->shape().dimensions(i))));
2678   }
2679 
2680   llvm_ir::LlvmIfData if_data =
2681       llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
2682   SetToFirstInsertPoint(if_data.true_block, b_);
2683 
2684   // We are not in pad, so do the computation.
2685   std::vector<llvm::Value*> input_values(reduce_window->operand_count());
2686   IrArray::Index input_index(
2687       input_multi_index, reduce_window->input_arrays()[0]->shape(), index_type);
2688   for (int64 operand_idx = 0; operand_idx < input_count; ++operand_idx) {
2689     TF_ASSIGN_OR_RETURN(llvm::Value * input_value,
2690                         input_generators[operand_idx](input_index));
2691     input_values[input_count + operand_idx] = input_value;
2692     input_values[operand_idx] = Load(accum_ptrs[operand_idx]);
2693   }
2694   TF_ASSIGN_OR_RETURN(std::vector<llvm::Value*> accum_values,
2695                       EmitThreadLocalCall(*reduce_window->to_apply(),
2696                                           input_values, "reducer_function"));
2697 
2698   for (int64 operand_idx = 0; operand_idx < accum_values.size();
2699        ++operand_idx) {
2700     Store(accum_values[operand_idx], accum_ptrs[operand_idx]);
2701   }
2702 
2703   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
2704   return EmitAccumResult(accum_ptrs, accum_types,
2705                          reduce_window->shape().IsTuple());
2706 }
2707 
EmitElementalReduce(const HloReduceInstruction * reduce,std::vector<llvm_ir::ElementGenerator> input_generators,std::vector<llvm_ir::ElementGenerator> initial_value_generators,const llvm_ir::IrArray::Index & index)2708 StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
2709     const HloReduceInstruction* reduce,
2710     std::vector<llvm_ir::ElementGenerator> input_generators,
2711     std::vector<llvm_ir::ElementGenerator> initial_value_generators,
2712     const llvm_ir::IrArray::Index& index) {
2713   const Shape& out_shape = reduce->shape();
2714   bool is_variadic = !out_shape.IsArray();
2715   int accumulators_count = 1;
2716   if (is_variadic) {
2717     CHECK(out_shape.IsTuple());
2718     accumulators_count = out_shape.tuple_shapes_size();
2719   }
2720 
2721   absl::Span<const int64> reduced_dimensions(reduce->dimensions());
2722 
2723   std::vector<llvm::Value*> accumulator_addrs;
2724   std::vector<llvm::Type*> accumulator_types;
2725   llvm::Type* index_type = index.GetType();
2726   for (int i = 0; i < accumulators_count; i++) {
2727     const Shape& element_shape =
2728         is_variadic ? out_shape.tuple_shapes(i) : out_shape;
2729     PrimitiveType accumulator_type = element_shape.element_type();
2730     llvm::Type* accumulator_llvm_type =
2731         llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_);
2732     accumulator_types.push_back(accumulator_llvm_type);
2733 
2734     // Initialize an accumulator with init_value.
2735     llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
2736         accumulator_llvm_type, "accumulator_" + std::to_string(i), b());
2737     TF_ASSIGN_OR_RETURN(
2738         llvm::Value* const init_value,
2739         initial_value_generators[i](llvm_ir::IrArray::Index(index_type)));
2740     Store(init_value, accumulator_addr);
2741     accumulator_addrs.push_back(accumulator_addr);
2742   }
2743 
2744   // The enclosing loops go over all the target elements. Now we have to compute
2745   // the actual target element. For this, we build a new loop nest to iterate
2746   // over all the reduction dimensions in the argument.
2747   // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
2748   // are placed for each dimension in dimensions, and all the rest are nullptrs.
2749   llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), b(), index_type);
2750   const HloInstruction* arg = reduce->operand(0);
2751   std::vector<llvm::Value*> input_multi_index =
2752       loops.AddLoopsForShapeOnDimensions(arg->shape(), reduced_dimensions,
2753                                          "reduction_dim");
2754 
2755   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());
2756 
2757   // Build a full index for the input argument, using input_multi_index as the
2758   // base. In input_multi_index only the reduction dimensions are filled in. We
2759   // fill in the rest of the dimensions with induction Value*s taken from
2760   // 'index' which iterates over the target array.  See the high-level
2761   // description in the XLA documentation for details.
2762   auto it = index.begin();
2763 
2764   for (auto& i : input_multi_index) {
2765     if (i == nullptr) {
2766       i = *it++;
2767     }
2768   }
2769   CHECK(index.end() == it);
2770   llvm_ir::IrArray::Index input_index(input_multi_index, arg->shape(),
2771                                       index_type);
2772 
2773   std::vector<llvm::Value*> reduction_operands;
2774   for (llvm::Value* accum : accumulator_addrs) {
2775     llvm::Value* accum_value = Load(accum);
2776     reduction_operands.push_back(accum_value);
2777   }
2778 
2779   for (int i = 0; i < accumulators_count; i++) {
2780     TF_ASSIGN_OR_RETURN(llvm::Value* const input_element,
2781                         input_generators[i](input_index));
2782     reduction_operands.push_back(input_element);
2783   }
2784 
2785   TF_ASSIGN_OR_RETURN(
2786       std::vector<llvm::Value*> results,
2787       EmitThreadLocalCall(*reduce->to_apply(), reduction_operands,
2788                           "reduce_function"));
2789 
2790   CHECK(results.size() == accumulators_count);
2791   for (int i = 0; i < accumulators_count; i++) {
2792     Store(results[i], accumulator_addrs[i]);
2793   }
2794   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());
2795   return EmitAccumResult(accumulator_addrs, accumulator_types, is_variadic);
2796 }
2797 
EmitAccumResult(absl::Span<llvm::Value * const> accumulator_addrs,llvm::ArrayRef<llvm::Type * > accumulator_types,bool is_variadic)2798 StatusOr<llvm::Value*> ElementalIrEmitter::EmitAccumResult(
2799     absl::Span<llvm::Value* const> accumulator_addrs,
2800     llvm::ArrayRef<llvm::Type*> accumulator_types, bool is_variadic) {
2801   TF_RET_CHECK(accumulator_addrs.size() == accumulator_types.size());
2802   if (is_variadic) {
2803     // Emit a structure, as that what the LoopEmitter expects.
2804     llvm::Value* returned_structure = llvm::UndefValue::get(
2805         llvm::StructType::get(b()->getContext(), accumulator_types));
2806     for (int64 i = 0; i < accumulator_addrs.size(); i++) {
2807       llvm::Value* accumulator_value = Load(accumulator_addrs[i]);
2808       returned_structure =
2809           b()->CreateInsertValue(returned_structure, accumulator_value, i);
2810     }
2811     return returned_structure;
2812   } else {
2813     CHECK_EQ(accumulator_addrs.size(), 1);
2814     return Load(accumulator_addrs[0]);
2815   }
2816 }
2817 
EmitConvolution(const HloInstruction * convolution,const ElementalIrEmitter::HloToElementGeneratorMap & operand_to_generator,const llvm_ir::IrArray::Index & index)2818 StatusOr<llvm::Value*> ElementalIrEmitter::EmitConvolution(
2819     const HloInstruction* convolution,
2820     const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
2821     const llvm_ir::IrArray::Index& index) {
2822   const HloInstruction* lhs = convolution->operand(0);
2823   const auto& input_generator = operand_to_generator.at(lhs);
2824   const HloInstruction* rhs = convolution->operand(1);
2825   const auto& kernel_generator = operand_to_generator.at(rhs);
2826   const Window& window = convolution->window();
2827 
2828   const ConvolutionDimensionNumbers& dnums =
2829       convolution->convolution_dimension_numbers();
2830   int num_spatial_dims = dnums.output_spatial_dimensions_size();
2831   std::vector<llvm::Value*> output_spatial(num_spatial_dims);
2832   for (int i = 0; i < num_spatial_dims; ++i) {
2833     output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
2834   }
2835   llvm::Value* output_feature = index[dnums.output_feature_dimension()];
2836   llvm::Value* batch = index[dnums.output_batch_dimension()];
2837 
2838   // We will accumulate the products into this sum to calculate the output entry
2839   // at the given index.
2840   PrimitiveType lhs_element_type = lhs->shape().element_type();
2841   llvm::Type* lhs_llvm_type =
2842       llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
2843   // Upcast the accumulator to F32 from F16 for increased precision.
2844   llvm::Type* accumulator_type =
2845       lhs_element_type == F16 ? b_->getFloatTy() : lhs_llvm_type;
2846   llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
2847       accumulator_type, "convolution_sum_address", b_);
2848   llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type);
2849   Store(constant_zero, sum_address);
2850 
2851   llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), b_);
2852   std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
2853   for (int i = 0; i < num_spatial_dims; ++i) {
2854     kernel_spatial[i] =
2855         loops
2856             .AddLoop(
2857                 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
2858                 absl::StrCat("k", i))
2859             ->GetIndVarValue();
2860   }
2861   llvm::Value* input_feature =
2862       loops
2863           .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
2864                    "iz")
2865           ->GetIndVarValue();
2866 
2867   SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
2868 
2869   // Calculate the spatial index in the input array, taking striding, dilation
2870   // and padding into account. An index in the padding will be out of the bounds
2871   // of the array.
2872   const auto calculate_input_index = [this](llvm::Value* output_index,
2873                                             llvm::Value* kernel_index,
2874                                             const WindowDimension& window_dim) {
2875     llvm::Value* strided_index =
2876         NSWMul(output_index, b_->getInt64(window_dim.stride()));
2877     llvm::Value* dilated_kernel_index =
2878         NSWMul(kernel_index, b_->getInt64(window_dim.window_dilation()));
2879     return NSWSub(NSWAdd(strided_index, dilated_kernel_index),
2880                   b_->getInt64(window_dim.padding_low()));
2881   };
2882   std::vector<llvm::Value*> input_spatial(num_spatial_dims);
2883   for (int i = 0; i < num_spatial_dims; ++i) {
2884     input_spatial[i] = calculate_input_index(
2885         output_spatial[i], kernel_spatial[i], window.dimensions(i));
2886   }
2887 
2888   // We need to check if 0 <= input dim < bound, as otherwise we are in the
2889   // padding so that we can skip the computation. That is equivalent to input
2890   // dim < bound as an *unsigned* comparison, since a negative value will wrap
2891   // to a large positive value. The input dim is dilated, so we need to dilate
2892   // the bound as well to match.
2893 
2894   // Also need to check that the input coordinates are not in one of the
2895   // holes created by base dilation.
2896   const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
2897     llvm::Value* remainder = SRem(input_index, b_->getInt64(base_dilation));
2898     return ICmpEQ(remainder, b_->getInt64(0));
2899   };
2900 
2901   llvm::Value* in_bounds_condition = b_->getInt1(true);
2902   for (int i = 0; i < num_spatial_dims; ++i) {
2903     llvm::ConstantInt* input_bound = b_->getInt64(window_util::DilatedBound(
2904         lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
2905         window.dimensions(i).base_dilation()));
2906     llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound);
2907     llvm::Value* dim_not_in_hole =
2908         not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
2909     llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole);
2910     in_bounds_condition = And(in_bounds_condition, dim_ok);
2911   }
2912 
2913   // Now we need to map the dilated base coordinates back to the actual
2914   // data indices on the lhs.
2915   const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
2916     return SDiv(input_index, b_->getInt64(base_dilation));
2917   };
2918   for (int i = 0; i < num_spatial_dims; ++i) {
2919     input_spatial[i] =
2920         undilate(input_spatial[i], window.dimensions(i).base_dilation());
2921   }
2922 
2923   llvm_ir::LlvmIfData if_data =
2924       llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", b_);
2925   SetToFirstInsertPoint(if_data.true_block, b_);
2926 
2927   // We are not in the padding, so carry out the computation.
2928   int num_dims = num_spatial_dims + 2;
2929   std::vector<llvm::Value*> input_multi_index(num_dims);
2930   for (int i = 0; i < num_spatial_dims; ++i) {
2931     input_multi_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
2932   }
2933   input_multi_index[dnums.input_feature_dimension()] = input_feature;
2934   input_multi_index[dnums.input_batch_dimension()] = batch;
2935 
2936   std::vector<llvm::Value*> kernel_multi_index(num_dims);
2937   for (int i = 0; i < num_spatial_dims; ++i) {
2938     kernel_multi_index[dnums.kernel_spatial_dimensions(i)] =
2939         window.dimensions(i).window_reversal()
2940             ? NSWSub(b_->getInt64(window.dimensions(i).size() - 1),
2941                      kernel_spatial[i])
2942             : kernel_spatial[i];
2943   }
2944 
2945   kernel_multi_index[dnums.kernel_input_feature_dimension()] = input_feature;
2946   kernel_multi_index[dnums.kernel_output_feature_dimension()] = output_feature;
2947 
2948   llvm_ir::IrArray::Index input_index(input_multi_index, lhs->shape(),
2949                                       b_->getInt64Ty());
2950   TF_ASSIGN_OR_RETURN(llvm::Value* const input_value,
2951                       input_generator(input_index));
2952   llvm_ir::IrArray::Index kernel_index(kernel_multi_index, rhs->shape(),
2953                                        b_->getInt64Ty());
2954   TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value,
2955                       kernel_generator(kernel_index));
2956   llvm::Value* sum = EmitMulAdd(input_value, kernel_value, Load(sum_address),
2957                                 convolution->shape().element_type());
2958   Store(sum, sum_address);
2959 
2960   SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
2961   return FPCast(Load(sum_address), lhs_llvm_type);
2962 }
2963 
2964 // Evaluate polynomial using Horner's method.
EvaluatePolynomial(llvm::Type * type,llvm::Value * x,absl::Span<const double> coefficients)2965 StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
2966     llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
2967   llvm::Value* poly = llvm::ConstantFP::get(type, 0.0);
2968   for (const double c : coefficients) {
2969     poly = FAdd(FMul(poly, x), llvm::ConstantFP::get(type, c));
2970   }
2971   return poly;
2972 }
2973 
2974 }  // namespace xla
2975