1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
17
18 #include "llvm/IR/Function.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/Intrinsics.h"
21 #include "llvm/IR/Verifier.h"
22 #include "llvm/Transforms/Utils/Cloning.h"
23 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
24 #include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
25 #include "tensorflow/core/platform/logging.h"
26
27 namespace xla {
28 namespace cpu {
29 namespace runtime {
30
// Symbol names of the vectorized f32 tanh / exp / log runtime functions.
// RewriteIRRuntimeFunctions (further down in this file) replaces calls to each
// of these symbols with inline IR, using the vector width encoded in the name
// (V4 -> 4 lanes, V8 -> 8 lanes, V16 -> 16 lanes).
//
// NOTE(review): unlike the tanh and exp names, the log names end in "AVX".
// These strings presumably must match wherever the symbols are referenced
// outside this file, so confirm all users before "normalizing" the suffix.
const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32";
const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32";
const char* const kTanhV16F32SymbolName = "__xla_cpu_runtime_TanhV16F32";
const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32";
const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32";
const char* const kExpV16F32SymbolName = "__xla_cpu_runtime_ExpV16F32";
const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX";
const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX";
const char* const kLogV16F32SymbolName = "__xla_cpu_runtime_LogV16F32AVX";
40
41 namespace {
42
43 // Removes 'fn' from the list of symbols to keep in 'module'.
RemoveFunctionFromUsedList(llvm::Module * module,llvm::Function * fn)44 void RemoveFunctionFromUsedList(llvm::Module* module, llvm::Function* fn) {
45 llvm::GlobalVariable* used = module->getGlobalVariable("llvm.compiler.used");
46 if (!used) {
47 return;
48 }
49
50 llvm::Type* int8_ptr_type = llvm::Type::getInt8PtrTy(module->getContext());
51 llvm::Constant* casted_fn = llvm::ConstantExpr::getBitCast(fn, int8_ptr_type);
52 auto* initializer = llvm::cast<llvm::ConstantArray>(used->getInitializer());
53 llvm::SmallVector<llvm::Constant*, 4> new_initializer;
54 for (auto& op : initializer->operands()) {
55 if (op != casted_fn) {
56 new_initializer.push_back(llvm::cast<llvm::Constant>(op));
57 }
58 }
59
60 if (new_initializer.size() == initializer->getNumOperands()) {
61 return;
62 }
63
64 used->eraseFromParent();
65 if (!new_initializer.empty()) {
66 llvm::ArrayType* array_type =
67 llvm::ArrayType::get(int8_ptr_type, new_initializer.size());
68 used = new llvm::GlobalVariable(
69 *module, array_type, /*isConstant=*/false,
70 llvm::GlobalValue::AppendingLinkage,
71 llvm::ConstantArray::get(array_type, new_initializer),
72 "llvm.compiler.used");
73 used->setSection("llvm.metadata");
74 }
75 }
76
77 // Replaces calls to the function `fn_name` with the code generated by
78 // fn_body_generator.
79 //
80 // We assume that fn_name accepts either a scalar f32 or a vector of
81 // vector_width f32s, and that fn_body_generator generates a function body with
82 // the same inputs/outputs as fn_name.
RewriteCalls(llvm::Module * module,const char * fn_name,std::function<llvm::Value * (llvm::IRBuilder<> * b,llvm::Value * input,int32 vector_width)> fn_body_generator,int32 vector_width,llvm::FastMathFlags fast_math_flags)83 void RewriteCalls(
84 llvm::Module* module, const char* fn_name,
85 std::function<llvm::Value*(llvm::IRBuilder<>* b, llvm::Value* input,
86 int32 vector_width)>
87 fn_body_generator,
88 int32 vector_width, llvm::FastMathFlags fast_math_flags) {
89 llvm::Function* fn = module->getFunction(fn_name);
90 if (fn == nullptr) {
91 // If the function declaration is not present in the module, there can't be
92 // any calls to resolve. Don't emit the function in this case.
93 return;
94 }
95
96 // Our task is to generate a function body for `fn`, but we can't generate a
97 // function body for an LLVM intrinsic. So if fn is an intrinsic, replace it
98 // with a new function.
99 if (fn->isIntrinsic()) {
100 llvm::Function* new_fn = llvm::Function::Create(
101 fn->getFunctionType(), llvm::GlobalValue::InternalLinkage,
102 llvm::Twine("xla_impl.") + fn_name, module);
103 fn->replaceAllUsesWith(new_fn);
104 fn->eraseFromParent();
105 fn = new_fn;
106 }
107
108 llvm::LLVMContext* context = &module->getContext();
109
110 llvm::BasicBlock* fn_body = llvm::BasicBlock::Create(*context, "body", fn);
111 llvm::IRBuilder<> b(fn_body);
112 b.setFastMathFlags(fast_math_flags);
113
114 llvm::Value* input = &*fn->arg_begin();
115
116 // Upcast to vector type if input is a scalar.
117 if (vector_width == 1) {
118 llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1, false);
119 input = b.CreateInsertElement(llvm::UndefValue::get(v1_type), input,
120 uint64_t{0});
121 }
122
123 // Generate the vectorized code.
124 CHECK_EQ(
125 vector_width,
126 llvm::cast<llvm::FixedVectorType>(input->getType())->getNumElements());
127 llvm::Value* result = fn_body_generator(&b, input, vector_width);
128
129 // Downcast result to scalar type if necessary.
130 if (vector_width == 1) {
131 result = b.CreateExtractElement(result, uint64_t{0});
132 }
133 b.CreateRet(result);
134 DCHECK(!llvm::verifyFunction(*fn));
135
136 // Force-inline `fn` into all of its callers and then delete `fn`.
137 //
138 // TODO(b/73081976): Should we avoid inlining these in some cases?
139 std::vector<llvm::CallInst*> calls_to_inline;
140 for (auto* user : fn->users()) {
141 if (auto* call = llvm::dyn_cast<llvm::CallInst>(user)) {
142 calls_to_inline.push_back(call);
143 }
144 }
145 for (auto* call_to_inline : calls_to_inline) {
146 llvm::InlineFunctionInfo inline_function_info;
147 CHECK(llvm::InlineFunction(*call_to_inline, inline_function_info)
148 .isSuccess());
149 }
150 // LLVM's InjectTLIMappings adds functions that might be used for
151 // vectorization to 'llvm.compiler.used'. Remove it before deleting the
152 // function.
153 RemoveFunctionFromUsedList(module, fn);
154 fn->eraseFromParent();
155 }
156
GenerateVF32Tanh(llvm::IRBuilder<> * b,llvm::Value * input,int32)157 llvm::Value* GenerateVF32Tanh(llvm::IRBuilder<>* b, llvm::Value* input,
158 int32 /*vector_width*/) {
159 return llvm_ir::EmitFastTanh(b, input);
160 }
161
// Emits IR that computes e^x element-wise on `input`, a vector of
// `vector_width` f32 values, and returns the resulting vector.
//
// Outline: the input is clamped to [-87.8, 88.8] and reduced as
// x = a + n * log(2) with n = round(x / log(2)). Then e^a is evaluated with
// the Cephes polynomial, and the result is e^a * 2^n. The 2^n factor is built
// directly from its IEEE-754 bit pattern.
llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input,
                             int32 vector_width) {
  VectorSupportLibrary vsl(F32, vector_width, b, "exp_f32");

  // This implements the same polynomial approximation as implemented in Cephes.
  const llvm::APFloat half = GetIeeeF32(0.5);
  const llvm::APFloat one = GetIeeeF32(1);

  // The constant 1/log(2).
  const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341);

  // C1 + C2 == log(2). C1 is a high part and C2 a small correction, so that
  // n * log(2) can be subtracted in two steps with less rounding error.
  const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375);
  const llvm::APFloat cephes_exp_C2 = GetIeeeF32(-2.12194440e-4);

  // Coefficients of the polynomial approximating e^a (see its use below).
  const llvm::APFloat cephes_exp_p0 = GetIeeeF32(1.9875691500E-4);
  const llvm::APFloat cephes_exp_p1 = GetIeeeF32(1.3981999507E-3);
  const llvm::APFloat cephes_exp_p2 = GetIeeeF32(8.3334519073E-3);
  const llvm::APFloat cephes_exp_p3 = GetIeeeF32(4.1665795894E-2);
  const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1);
  const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1);

  // To compute e^x, we re-express it as
  //
  //   e^x = e^(a + b)
  //       = e^(a + n log(2))
  //       = e^a * 2^n.
  //
  // We choose n = round(x / log(2)), restricting the value of `a` to
  // (-log(2)/2, log(2)/2). We then use a polynomial to compute e^a. The
  // relative error between our approximation and the true value of e^a is less
  // than 2^-22.5 for all values of `a` within this range.

  // Restrict input to a small range. The range still includes some values
  // whose results overflow to +inf (at the top) or flush to 0 (at the bottom).
  // For the lower bound we choose log(2^-126) instead of log(2^-149), the log
  // of the smallest positive denormal. We do so because this routine always
  // flushes denormal floating points to 0. Therefore, we only need to worry
  // about exponentiating up to the smallest representable non-denormal
  // floating point, which is 2^-126.
  //
  // Our computations below aren't particularly sensitive to the exact choices
  // here, so we choose values a bit larger/smaller than
  //
  //   log(F32_MAX) = 88.723...
  //   log(2^-126)  = -87.337...
  input = vsl.Clamp(input, GetIeeeF32(-87.8), GetIeeeF32(88.8));

  llvm::Value* x = input;

  // Calculates n = floor(input / log(2) + 0.5) = round(input / log(2))
  llvm::Value* n = vsl.Floor(vsl.MulAdd(input, cephes_LOG2EF, half));

  // When we eventually do the multiplication in e^a * 2^n, we need to handle
  // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
  // (so e^a * 2^n != inf). There's a similar problem for n < -126, the
  // smallest fp32 exponent.
  //
  // A straightforward solution would be to detect n out of range and split it
  // up, doing
  //
  //   e^a * 2^n = e^a * 2^(n1 + n2)
  //             = (2^n1 * e^a) * 2^n2.
  //
  // But it turns out this approach is quite slow, probably because it
  // manipulates subnormal values.
  //
  // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
  // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow
  // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though
  // this value of `a` is outside our previously specified range, e^a will still
  // only have a relative error of approximately 2^-16 at worst. In practice
  // this seems to work well enough; it passes our exhaustive tests, breaking
  // only one result, and by one ulp (we return exp(88.7228394) = max-float but
  // we should return inf).
  //
  // In the case where n' = -127, the original input value of x is so small that
  // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
  // normal floating point, and since we flush denormals, we simply return 0. We
  // do this in a branchless way by observing that our code for constructing 2^n
  // produces 0 if n = -127.
  //
  // The proof that n' = -127 implies e^x < 2^-126 is as follows:
  //
  //    n' = -127 implies n <= -127
  //              implies round(x / log(2)) <= -127
  //              implies x/log(2) < -126.5
  //              implies x < -126.5 * log(2)
  //              implies e^x < e^(-126.5 * log(2))
  //              implies e^x < 2^-126.5 < 2^-126
  //
  //    This proves that n' = -127 implies e^x < 2^-126.
  n = vsl.Clamp(n, GetIeeeF32(-127), GetIeeeF32(127));

  // Computes x = x - n' * log(2), the value for `a`. log(2) is subtracted as
  // C1 then C2 (see above) to limit rounding error.
  x = vsl.Sub(x, vsl.Mul(cephes_exp_C1, n));
  x = vsl.Sub(x, vsl.Mul(cephes_exp_C2, n));

  // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5). In Horner
  // form this evaluates
  //   z = 1 + a + a^2 * (p5 + a*(p4 + a*(p3 + a*(p2 + a*(p1 + a*p0))))).
  llvm::Value* z = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1);
  z = vsl.MulAdd(z, x, cephes_exp_p2);
  z = vsl.MulAdd(z, x, cephes_exp_p3);
  z = vsl.MulAdd(z, x, cephes_exp_p4);
  z = vsl.MulAdd(z, x, cephes_exp_p5);
  z = vsl.MulAdd(z, vsl.Mul(x, x), x);
  z = vsl.Add(one, z);

  // Convert n' to an i32. This is safe because we clamped it above.
  llvm::Value* n_i32 = b->CreateFPToSI(
      n, llvm::VectorType::get(b->getInt32Ty(), vector_width, false));

  auto splat_i32 = [&](int32 v) {
    return b->CreateVectorSplat(vector_width, b->getInt32(v));
  };

  // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127. It does so
  // by writing the biased exponent (n' + 127) into the exponent field with a
  // zero significand. For n' = -127 the whole bit pattern is 0, i.e. 0.0f.
  const int32 kF32SignificandBits = 23;
  llvm::Value* exp_bias = splat_i32(0x7f);
  llvm::Value* pow2 =
      b->CreateBitCast(b->CreateShl(b->CreateAdd(n_i32, exp_bias),
                                    splat_i32(kF32SignificandBits)),
                       vsl.vector_type());

  // Return z * 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
  return vsl.Mul(z, pow2);
}
286
// Emits IR that computes the natural logarithm element-wise on `input`, a
// vector of `vector_width` f32 values, and returns the resulting vector.
//
// Special cases are handled branch-free with bitwise lane masks:
//   x == +0 or -0   -> -inf
//   x == +inf       -> +inf
//   x < 0 or NaN    -> NaN
// Positive denormal inputs are clamped up to the smallest normal float, so they
// yield log(2^-126) rather than their exact logarithm.
llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input,
                             int32 vector_width) {
  VectorSupportLibrary vsl(F32, vector_width, b, "log_f32");

  const llvm::APFloat half = GetIeeeF32(0.5);
  const llvm::APFloat one = GetIeeeF32(1.0);

  // This implements the same polynomial approximation as implemented in Eigen3
  // (the constants keep their Cephes names).
  // Returns NaN for x < 0, -INF for x = 0.
  // cephes_SQRTHF is sqrt(1/2).
  const llvm::APFloat cephes_SQRTHF = GetIeeeF32(0.707106781186547524);
  const llvm::APFloat cephes_log_p0 = GetIeeeF32(7.0376836292E-2);
  const llvm::APFloat cephes_log_p1 = GetIeeeF32(-1.1514610310E-1);
  const llvm::APFloat cephes_log_p2 = GetIeeeF32(1.1676998740E-1);
  const llvm::APFloat cephes_log_p3 = GetIeeeF32(-1.2420140846E-1);
  const llvm::APFloat cephes_log_p4 = GetIeeeF32(+1.4249322787E-1);
  const llvm::APFloat cephes_log_p5 = GetIeeeF32(-1.6668057665E-1);
  const llvm::APFloat cephes_log_p6 = GetIeeeF32(+2.0000714765E-1);
  const llvm::APFloat cephes_log_p7 = GetIeeeF32(-2.4999993993E-1);
  const llvm::APFloat cephes_log_p8 = GetIeeeF32(+3.3333331174E-1);
  // q2 + q1 == log(2), split into a high part (q2) and a small correction (q1)
  // so that e * log(2) can be added with less rounding error.
  const llvm::APFloat cephes_log_q1 = GetIeeeF32(-2.12194440e-4);
  const llvm::APFloat cephes_log_q2 = GetIeeeF32(0.693359375);

  // The smallest non-denormalized (normal) positive float, 2^-126.
  const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000);
  const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000);
  const llvm::APFloat pos_inf = GetIeeeF32FromBitwiseRep(0x7f800000);
  // Bit pattern 0x807fffff: clears the 8 exponent bits and keeps the sign and
  // mantissa bits.
  const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000);

  // invalid_mask is set if x <= 0 or x is NaN. Lanes with x == 0 are overridden
  // with -inf at the end; every other set lane must produce NaN.
  llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector());
  llvm::Value* is_zero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector());
  llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf);

  // Cut off denormalized stuff: anything below the smallest normal float
  // (including negative lanes, which invalid_mask handles later) is raised to
  // it.
  // Always allow fast max because we are checking for the nan above.
  llvm::Value* tmp0 =
      vsl.Max(min_norm_pos, input, /*enable_fast_min_max=*/true);

  // VectorSupportLibrary (intentionally) can't juggle more than one type at a
  // time so drop down to IRBuilder for this bit.
  llvm::Value* vector_constant_0x7f =
      b->CreateVectorSplat(vector_width, b->getInt32(0x7f));
  llvm::Value* vector_constant_23 =
      b->CreateVectorSplat(vector_width, b->getInt32(23));
  llvm::Type* i32_vector_type =
      llvm::VectorType::get(b->getInt32Ty(), vector_width, false);

  // Shift out the 23 mantissa bits, leaving the biased exponent. The sign bit
  // is clear in the positive lanes whose result is actually used.
  llvm::Value* emm0 = b->CreateLShr(b->CreateBitCast(tmp0, i32_vector_type),
                                    vector_constant_23);

  // Keep only the fractional part. Clearing the exponent and OR-ing in the
  // exponent of 0.5 turns tmp0 into its mantissa, scaled into [0.5, 1).
  tmp0 = vsl.FloatAnd(tmp0, inv_mant_mask);
  tmp0 = vsl.FloatOr(tmp0, half);

  // Unbias the exponent. The mantissa was scaled into [0.5, 1) rather than
  // [1, 2), so the matching exponent is one larger: e = (emm0 - 127) + 1.
  emm0 = b->CreateSub(emm0, vector_constant_0x7f);
  llvm::Value* e = vsl.Add(one, b->CreateSIToFP(emm0, vsl.vector_type()));

  // part2:
  //   if( x < SQRTHF ) {
  //     e -= 1;
  //     x = x + x - 1.0;
  //   } else { x = x - 1.0; }
  // This is computed branch-free. `mask` selects the lanes where x < SQRTHF:
  // tmp1 is x in those lanes and 0 elsewhere, and mask & 1.0 is 1 or 0.
  llvm::Value* mask = vsl.FCmpOLTMask(tmp0, cephes_SQRTHF);
  llvm::Value* tmp1 = vsl.FloatAnd(tmp0, mask);
  tmp0 = vsl.Sub(tmp0, one);
  e = vsl.Sub(e, vsl.FloatAnd(mask, one));
  tmp0 = vsl.Add(tmp0, tmp1);

  // Evaluate the polynomial in x (tmp0). Three partial polynomials (y, y1, y2)
  // are built independently and then combined using powers of x^3.
  llvm::Value* x2 = vsl.Mul(tmp0, tmp0);
  llvm::Value* x3 = vsl.Mul(x2, tmp0);

  llvm::Value *y, *y1, *y2;
  y = vsl.MulAdd(tmp0, cephes_log_p0, cephes_log_p1);
  y1 = vsl.MulAdd(tmp0, cephes_log_p3, cephes_log_p4);
  y2 = vsl.MulAdd(tmp0, cephes_log_p6, cephes_log_p7);
  y = vsl.MulAdd(y, tmp0, cephes_log_p2);
  y1 = vsl.MulAdd(y1, tmp0, cephes_log_p5);
  y2 = vsl.MulAdd(y2, tmp0, cephes_log_p8);
  y = vsl.MulAdd(y, x3, y1);
  y = vsl.MulAdd(y, x3, y2);
  y = vsl.Mul(y, x3);

  // Result = x - x^2/2 + y + e * log(2). The e * log(2) term is added as
  // e * q1 (folded into y) and then e * q2.
  y1 = vsl.Mul(cephes_log_q1, e);
  llvm::Value* tmp2 = vsl.Mul(half, x2);
  y = vsl.Add(y, y1);
  tmp0 = vsl.Sub(tmp0, tmp2);
  y2 = vsl.Mul(cephes_log_q2, e);
  tmp0 = vsl.Add(tmp0, y);
  tmp0 = vsl.Add(tmp0, y2);

  // Contains +/-inf where +/-inf is the correct answer, otherwise 0.
  llvm::Value* result_inf = vsl.FloatOr(vsl.FloatAnd(is_zero_mask, minus_inf),
                                        vsl.FloatAnd(is_pos_inf_mask, pos_inf));

  // Contains a finite result or nan. This is the correct answer only in lanes
  // where both is_zero_mask and is_pos_inf_mask are 0.
  //
  // (This implementation works because 0xffffffff is a nan.)
  llvm::Value* result_finite_or_nan = vsl.FloatOr(tmp0, invalid_mask);

  // Combine the above into a final result.
  return vsl.FloatOr(result_inf,
                     vsl.FloatAndNot(vsl.FloatOr(is_zero_mask, is_pos_inf_mask),
                                     result_finite_or_nan));
}
393 } // namespace
394
RewriteIRRuntimeFunctions(llvm::Module * module,llvm::FastMathFlags fast_math_flags)395 void RewriteIRRuntimeFunctions(llvm::Module* module,
396 llvm::FastMathFlags fast_math_flags) {
397 // Curry some params to RewriteCalls.
398 auto rewrite_calls =
399 std::bind(RewriteCalls, module, std::placeholders::_1,
400 std::placeholders::_2, std::placeholders::_3, fast_math_flags);
401
402 rewrite_calls("tanhf", GenerateVF32Tanh, /*vector_width=*/1);
403 rewrite_calls("llvm.tanh.f32", GenerateVF32Tanh, /*vector_width=*/1);
404 rewrite_calls(kTanhV4F32SymbolName, GenerateVF32Tanh, /*vector_width=*/4);
405 rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8);
406 rewrite_calls(kTanhV16F32SymbolName, GenerateVF32Tanh, /*vector_width=*/16);
407
408 rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1);
409 rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1);
410 rewrite_calls(kExpV4F32SymbolName, GenerateVF32Exp, /*vector_width=*/4);
411 rewrite_calls(kExpV8F32SymbolName, GenerateVF32Exp, /*vector_width=*/8);
412 rewrite_calls(kExpV16F32SymbolName, GenerateVF32Exp, /*vector_width=*/16);
413
414 rewrite_calls("logf", GenerateVF32Log, /*vector_width=*/1);
415 rewrite_calls("llvm.log.f32", GenerateVF32Log, /*vector_width=*/1);
416 rewrite_calls(kLogV4F32SymbolName, GenerateVF32Log, /*vector_width=*/4);
417 rewrite_calls(kLogV8F32SymbolName, GenerateVF32Log, /*vector_width=*/8);
418 rewrite_calls(kLogV16F32SymbolName, GenerateVF32Log, /*vector_width=*/16);
419 }
420
421 } // namespace runtime
422 } // namespace cpu
423 } // namespace xla
424