/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"

#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace cpu {
namespace runtime {

const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32";
const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32";
const char* const kTanhV16F32SymbolName = "__xla_cpu_runtime_TanhV16F32";
const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32";
const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32";
const char* const kExpV16F32SymbolName = "__xla_cpu_runtime_ExpV16F32";
const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX";
const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX";
const char* const kLogV16F32SymbolName = "__xla_cpu_runtime_LogV16F32AVX";
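
// Calls to the symbols above are emitted elsewhere in the CPU backend when
// tanh, exp, and log are vectorized; RewriteIRRuntimeFunctions below replaces
// those calls (along with their scalar libm and LLVM-intrinsic forms) with
// inlined bodies.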

namespace {

// Removes 'fn' from the list of symbols to keep in 'module'.
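// For example, the used list in IR typically looks like this (a sketch; the
// exact element types vary across LLVM versions):
//
//   @llvm.compiler.used = appending global [1 x i8*]
//       [i8* bitcast (float (float)* @expf to i8*)], section "llvm.metadata"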
void RemoveFunctionFromUsedList(llvm::Module* module, llvm::Function* fn) {
  llvm::GlobalVariable* used = module->getGlobalVariable("llvm.compiler.used");
  if (!used) {
    return;
  }

  llvm::Type* int8_ptr_type = llvm::Type::getInt8PtrTy(module->getContext());
  llvm::Constant* casted_fn = llvm::ConstantExpr::getBitCast(fn, int8_ptr_type);
  auto* initializer = llvm::cast<llvm::ConstantArray>(used->getInitializer());
  llvm::SmallVector<llvm::Constant*, 4> new_initializer;
  for (auto& op : initializer->operands()) {
    if (op != casted_fn) {
      new_initializer.push_back(llvm::cast<llvm::Constant>(op));
    }
  }

  if (new_initializer.size() == initializer->getNumOperands()) {
    return;
  }

  used->eraseFromParent();
  if (!new_initializer.empty()) {
    llvm::ArrayType* array_type =
        llvm::ArrayType::get(int8_ptr_type, new_initializer.size());
    used = new llvm::GlobalVariable(
        *module, array_type, /*isConstant=*/false,
        llvm::GlobalValue::AppendingLinkage,
        llvm::ConstantArray::get(array_type, new_initializer),
        "llvm.compiler.used");
    used->setSection("llvm.metadata");
  }
}

// Replaces calls to the function `fn_name` with the code generated by
// fn_body_generator.
//
// We assume that fn_name accepts either a scalar f32 or a vector of
// vector_width f32s, and that fn_body_generator generates a function body with
// the same inputs/outputs as fn_name.
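// For example, with fn_name = "expf" and vector_width = 1, a scalar call like
//
//   %y = call float @expf(float %x)
//
// ends up replaced by the inlined body produced by fn_body_generator.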
void RewriteCalls(
    llvm::Module* module, const char* fn_name,
    std::function<llvm::Value*(llvm::IRBuilder<>* b, llvm::Value* input,
                               int32 vector_width)>
        fn_body_generator,
    int32 vector_width, llvm::FastMathFlags fast_math_flags) {
  llvm::Function* fn = module->getFunction(fn_name);
  if (fn == nullptr) {
    // If the function declaration is not present in the module, there can't be
    // any calls to resolve.  Don't emit the function in this case.
    return;
  }

  // Our task is to generate a function body for `fn`, but we can't generate a
  // function body for an LLVM intrinsic. So if fn is an intrinsic, replace it
  // with a new function.
  if (fn->isIntrinsic()) {
    llvm::Function* new_fn = llvm::Function::Create(
        fn->getFunctionType(), llvm::GlobalValue::InternalLinkage,
        llvm::Twine("xla_impl.") + fn_name, module);
    fn->replaceAllUsesWith(new_fn);
    fn->eraseFromParent();
    fn = new_fn;
  }

  llvm::LLVMContext* context = &module->getContext();

  llvm::BasicBlock* fn_body = llvm::BasicBlock::Create(*context, "body", fn);
  llvm::IRBuilder<> b(fn_body);
  b.setFastMathFlags(fast_math_flags);

  llvm::Value* input = &*fn->arg_begin();

  // Upcast to vector type if input is a scalar.
  if (vector_width == 1) {
    llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1, false);
    input = b.CreateInsertElement(llvm::UndefValue::get(v1_type), input,
                                  uint64_t{0});
  }

  // Generate the vectorized code.
  CHECK_EQ(
      vector_width,
      llvm::cast<llvm::FixedVectorType>(input->getType())->getNumElements());
  llvm::Value* result = fn_body_generator(&b, input, vector_width);

  // Downcast result to scalar type if necessary.
  if (vector_width == 1) {
    result = b.CreateExtractElement(result, uint64_t{0});
  }
  b.CreateRet(result);
  DCHECK(!llvm::verifyFunction(*fn));

  // Force-inline `fn` into all of its callers and then delete `fn`.
  //
  // TODO(b/73081976): Should we avoid inlining these in some cases?
  std::vector<llvm::CallInst*> calls_to_inline;
  for (auto* user : fn->users()) {
    if (auto* call = llvm::dyn_cast<llvm::CallInst>(user)) {
      calls_to_inline.push_back(call);
    }
  }
  for (auto* call_to_inline : calls_to_inline) {
    llvm::InlineFunctionInfo inline_function_info;
    CHECK(llvm::InlineFunction(*call_to_inline, inline_function_info)
              .isSuccess());
  }
  // LLVM's InjectTLIMappings adds functions that might be used for
  // vectorization to 'llvm.compiler.used'. Remove 'fn' from that list before
  // deleting it.
  RemoveFunctionFromUsedList(module, fn);
  fn->eraseFromParent();
}
llvm::Value* GenerateVF32Tanh(llvm::IRBuilder<>* b, llvm::Value* input,
                              int32 /*vector_width*/) {
  return llvm_ir::EmitFastTanh(b, input);
}

llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input,
                             int32 vector_width) {
  VectorSupportLibrary vsl(F32, vector_width, b, "exp_f32");

  // This implements the same polynomial approximation as Cephes.
  const llvm::APFloat half = GetIeeeF32(0.5);
  const llvm::APFloat one = GetIeeeF32(1);

  // The constant 1/log(2).
  const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341);

  const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375);
  const llvm::APFloat cephes_exp_C2 = GetIeeeF32(-2.12194440e-4);

  const llvm::APFloat cephes_exp_p0 = GetIeeeF32(1.9875691500E-4);
  const llvm::APFloat cephes_exp_p1 = GetIeeeF32(1.3981999507E-3);
  const llvm::APFloat cephes_exp_p2 = GetIeeeF32(8.3334519073E-3);
  const llvm::APFloat cephes_exp_p3 = GetIeeeF32(4.1665795894E-2);
  const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1);
  const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1);

  // To compute e^x, we re-express it as
  //
  //   e^x = e^(a + b)
  //       = e^(a + n log(2))
  //       = e^a * 2^n.
  //
  // We choose n = round(x / log(2)), restricting the value of `a` to
  // (-log(2)/2, log(2)/2).  We then use a polynomial to compute e^a. The
  // relative error between our approximation and the true value of e^a is less
  // than 2^-22.5 for all values of `a` within this range.
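  //
  // For example, for x = 10: n = round(10 / log(2)) = 14, so
  // a = 10 - 14 * log(2) ~= 0.2959 and e^10 ~= e^0.2959 * 2^14 ~= 22026.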

  // Restrict input to a small range, including some values that evaluate to
  // +/- inf.  Note that for our lower bound, we choose log(2^-126) instead of
  // log(F32_EPSILON). We do so because this routine always flushes denormal
  // floating points to 0. Therefore, we only need to worry about exponentiating
  // up to the smallest representable non-denormal floating point, which is
  // 2^-126.
  //
  // Our computations below aren't particularly sensitive to the exact choices
  // here, so we choose values a bit larger/smaller than
  //
  //   log(F32_MAX) = 88.723...
  //   log(2^-126) = -87.337...
  input = vsl.Clamp(input, GetIeeeF32(-87.8), GetIeeeF32(88.8));

  llvm::Value* x = input;

  // Calculates n = floor(input / log(2) + 0.5) = round(input / log(2))
  llvm::Value* n = vsl.Floor(vsl.MulAdd(input, cephes_LOG2EF, half));
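  // (floor(y + 0.5) rounds y to the nearest integer: e.g. input = 1.0f gives
  // floor(1.0 * 1.4427 + 0.5) = floor(1.9427) = 1.)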

  // When we eventually do the multiplication in e^a * 2^n, we need to handle
  // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
  // (so e^a * 2^n != inf).  There's a similar problem for n < -126, the
  // smallest fp32 exponent.
  //
  // A straightforward solution would be to detect n out of range and split it
  // up, doing
  //
  //   e^a * 2^n = e^a * 2^(n1 + n2)
  //             = (2^n1 * e^a) * 2^n2.
  //
  // But it turns out this approach is quite slow, probably because it
  // manipulates subnormal values.
  //
  // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
  // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can be
  // as large as 88.8 - 127 * log(2), which is about 0.7703. Even though this
  // value of `a` is outside our previously specified range, e^a will still
  // only have a relative error of approximately 2^-16 at worst. In practice
  // this seems to work well enough; it passes our exhaustive tests, breaking
  // only one result, and by one ulp (we return exp(88.7228394) = max-float but
  // we should return inf).
  //
  // In the case where n' = -127, the original input value of x is so small that
  // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
  // normal floating point, and since we flush denormals, we simply return 0. We
  // do this in a branchless way by observing that our code for constructing 2^n
  // produces 0 if n = -127.
  //
  // The proof that n' = -127 implies e^x < 2^-126 is as follows:
  //
  //    n' = -127 implies n <= -127
  //              implies round(x / log(2)) <= -127
  //              implies x / log(2) < -126.5
  //              implies x < -126.5 * log(2)
  //              implies e^x < e^(-126.5 * log(2))
  //              implies e^x < 2^-126.5 < 2^-126
  //
  //    This proves that n' = -127 implies e^x < 2^-126.
  n = vsl.Clamp(n, GetIeeeF32(-127), GetIeeeF32(127));

  // Computes x = x - n' * log(2), the value for `a`
  x = vsl.Sub(x, vsl.Mul(cephes_exp_C1, n));
  x = vsl.Sub(x, vsl.Mul(cephes_exp_C2, n));

  // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
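  // Expanded, the Horner chain below computes
  //
  //   e^a ~= 1 + a + a^2 * (p5 + p4*a + p3*a^2 + p2*a^3 + p1*a^4 + p0*a^5).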
  llvm::Value* z = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1);
  z = vsl.MulAdd(z, x, cephes_exp_p2);
  z = vsl.MulAdd(z, x, cephes_exp_p3);
  z = vsl.MulAdd(z, x, cephes_exp_p4);
  z = vsl.MulAdd(z, x, cephes_exp_p5);
  z = vsl.MulAdd(z, vsl.Mul(x, x), x);
  z = vsl.Add(one, z);

  // Convert n' to an i32.  This is safe because we clamped it above.
  llvm::Value* n_i32 = b->CreateFPToSI(
      n, llvm::VectorType::get(b->getInt32Ty(), vector_width, false));

  auto splat_i32 = [&](int32 v) {
    return b->CreateVectorSplat(vector_width, b->getInt32(v));
  };

  // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
  const int32 kF32SignificandBits = 23;
  llvm::Value* exp_bias = splat_i32(0x7f);
  llvm::Value* pow2 =
      b->CreateBitCast(b->CreateShl(b->CreateAdd(n_i32, exp_bias),
                                    splat_i32(kF32SignificandBits)),
                       vsl.vector_type());
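  // For example, n' = 3 gives (3 + 0x7f) << 23 = 0x41000000, the bit pattern
  // of 8.0f = 2^3, while n' = -127 gives (-127 + 0x7f) << 23 = 0, i.e. 0.0f.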

  // Return z * 2^n' if -126 <= n' <= 127, and 0 if n' = -127.
  return vsl.Mul(z, pow2);
}

llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input,
                             int32 vector_width) {
  VectorSupportLibrary vsl(F32, vector_width, b, "log_f32");

  const llvm::APFloat half = GetIeeeF32(0.5);
  const llvm::APFloat one = GetIeeeF32(1.0);

  // This implements the same polynomial approximation as Eigen3.
  // Returns NaN for x < 0 and -INF for x = 0.
  const llvm::APFloat cephes_SQRTHF = GetIeeeF32(0.707106781186547524);
  const llvm::APFloat cephes_log_p0 = GetIeeeF32(7.0376836292E-2);
  const llvm::APFloat cephes_log_p1 = GetIeeeF32(-1.1514610310E-1);
  const llvm::APFloat cephes_log_p2 = GetIeeeF32(1.1676998740E-1);
  const llvm::APFloat cephes_log_p3 = GetIeeeF32(-1.2420140846E-1);
  const llvm::APFloat cephes_log_p4 = GetIeeeF32(+1.4249322787E-1);
  const llvm::APFloat cephes_log_p5 = GetIeeeF32(-1.6668057665E-1);
  const llvm::APFloat cephes_log_p6 = GetIeeeF32(+2.0000714765E-1);
  const llvm::APFloat cephes_log_p7 = GetIeeeF32(-2.4999993993E-1);
  const llvm::APFloat cephes_log_p8 = GetIeeeF32(+3.3333331174E-1);
  const llvm::APFloat cephes_log_q1 = GetIeeeF32(-2.12194440e-4);
  const llvm::APFloat cephes_log_q2 = GetIeeeF32(0.693359375);

  // The smallest normalized (i.e. non-denormal) float, 2^-126.
  const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000);
  const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000);
  const llvm::APFloat pos_inf = GetIeeeF32FromBitwiseRep(0x7f800000);
  const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000);

  // invalid_mask is set if x is non-positive or NaN (and therefore the output
  // must be NaN). The x == 0 lanes are overridden by the -inf path below.
  llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector());
  llvm::Value* is_zero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector());
  llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf);

  // Cut off denormalized stuff.
  // Always allow fast max because we are checking for the nan above.
  llvm::Value* tmp0 =
      vsl.Max(min_norm_pos, input, /*enable_fast_min_max=*/true);

  // VectorSupportLibrary (intentionally) can't juggle more than one type at a
  // time so drop down to IRBuilder for this bit.
  llvm::Value* vector_constant_0x7f =
      b->CreateVectorSplat(vector_width, b->getInt32(0x7f));
  llvm::Value* vector_constant_23 =
      b->CreateVectorSplat(vector_width, b->getInt32(23));
  llvm::Type* i32_vector_type =
      llvm::VectorType::get(b->getInt32Ty(), vector_width, false);

  llvm::Value* emm0 = b->CreateLShr(b->CreateBitCast(tmp0, i32_vector_type),
                                    vector_constant_23);
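  // For example, tmp0 = 8.0f has bit pattern 0x41000000, so emm0 =
  // 0x41000000 >> 23 = 0x82 = 130; subtracting the 0x7f exponent bias below
  // leaves 3. The extra 1 added to `e` accounts for the mantissa being
  // normalized to [0.5, 1) rather than [1, 2), so 8.0 = 0.5 * 2^4 gives e = 4.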

  // Keep only the fractional part.
  tmp0 = vsl.FloatAnd(tmp0, inv_mant_mask);
  tmp0 = vsl.FloatOr(tmp0, half);

  emm0 = b->CreateSub(emm0, vector_constant_0x7f);
  llvm::Value* e = vsl.Add(one, b->CreateSIToFP(emm0, vsl.vector_type()));

  // part2:
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
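  //
  // The masks below are all-ones (as float bit patterns) in lanes where the
  // comparison holds, so FloatAnd(mask, v) selects v in those lanes and 0
  // elsewhere; this implements the pseudocode above without branches.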
  llvm::Value* mask = vsl.FCmpOLTMask(tmp0, cephes_SQRTHF);
  llvm::Value* tmp1 = vsl.FloatAnd(tmp0, mask);
  tmp0 = vsl.Sub(tmp0, one);
  e = vsl.Sub(e, vsl.FloatAnd(mask, one));
  tmp0 = vsl.Add(tmp0, tmp1);

  llvm::Value* x2 = vsl.Mul(tmp0, tmp0);
  llvm::Value* x3 = vsl.Mul(x2, tmp0);

  llvm::Value *y, *y1, *y2;
  y = vsl.MulAdd(tmp0, cephes_log_p0, cephes_log_p1);
  y1 = vsl.MulAdd(tmp0, cephes_log_p3, cephes_log_p4);
  y2 = vsl.MulAdd(tmp0, cephes_log_p6, cephes_log_p7);
  y = vsl.MulAdd(y, tmp0, cephes_log_p2);
  y1 = vsl.MulAdd(y1, tmp0, cephes_log_p5);
  y2 = vsl.MulAdd(y2, tmp0, cephes_log_p8);
  y = vsl.MulAdd(y, x3, y1);
  y = vsl.MulAdd(y, x3, y2);
  y = vsl.Mul(y, x3);

  y1 = vsl.Mul(cephes_log_q1, e);
  llvm::Value* tmp2 = vsl.Mul(half, x2);
  y = vsl.Add(y, y1);
  tmp0 = vsl.Sub(tmp0, tmp2);
  y2 = vsl.Mul(cephes_log_q2, e);
  tmp0 = vsl.Add(tmp0, y);
  tmp0 = vsl.Add(tmp0, y2);

  // Contains +/-inf where +/-inf is the correct answer, otherwise 0.
  llvm::Value* result_inf = vsl.FloatOr(vsl.FloatAnd(is_zero_mask, minus_inf),
                                        vsl.FloatAnd(is_pos_inf_mask, pos_inf));
  // Contains a finite result or NaN. This is the correct answer only in lanes
  // where result_inf is 0.
  //
  // (This works because the all-ones bit pattern 0xffffffff is a NaN.)
  llvm::Value* result_finite_or_nan = vsl.FloatOr(tmp0, invalid_mask);

  // Combine the above into the final result.
  return vsl.FloatOr(result_inf,
                     vsl.FloatAndNot(vsl.FloatOr(is_zero_mask, is_pos_inf_mask),
                                     result_finite_or_nan));
}
}  // namespace

void RewriteIRRuntimeFunctions(llvm::Module* module,
                               llvm::FastMathFlags fast_math_flags) {
  // Curry some params to RewriteCalls.
  auto rewrite_calls =
      std::bind(RewriteCalls, module, std::placeholders::_1,
                std::placeholders::_2, std::placeholders::_3, fast_math_flags);

  rewrite_calls("tanhf", GenerateVF32Tanh, /*vector_width=*/1);
  rewrite_calls("llvm.tanh.f32", GenerateVF32Tanh, /*vector_width=*/1);
  rewrite_calls(kTanhV4F32SymbolName, GenerateVF32Tanh, /*vector_width=*/4);
  rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8);
  rewrite_calls(kTanhV16F32SymbolName, GenerateVF32Tanh, /*vector_width=*/16);

  rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1);
  rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1);
  rewrite_calls(kExpV4F32SymbolName, GenerateVF32Exp, /*vector_width=*/4);
  rewrite_calls(kExpV8F32SymbolName, GenerateVF32Exp, /*vector_width=*/8);
  rewrite_calls(kExpV16F32SymbolName, GenerateVF32Exp, /*vector_width=*/16);

  rewrite_calls("logf", GenerateVF32Log, /*vector_width=*/1);
  rewrite_calls("llvm.log.f32", GenerateVF32Log, /*vector_width=*/1);
  rewrite_calls(kLogV4F32SymbolName, GenerateVF32Log, /*vector_width=*/4);
  rewrite_calls(kLogV8F32SymbolName, GenerateVF32Log, /*vector_width=*/8);
  rewrite_calls(kLogV16F32SymbolName, GenerateVF32Log, /*vector_width=*/16);
}

}  // namespace runtime
}  // namespace cpu
}  // namespace xla