1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/base/casts.h"
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_cat.h"
25 #include "llvm/ADT/Triple.h"
26 #include "llvm/IR/DerivedTypes.h"
27 #include "llvm/IR/GlobalValue.h"
28 #include "llvm/IR/GlobalVariable.h"
29 #include "llvm/IR/MDBuilder.h"
30 #include "llvm/IR/Operator.h"
31 #include "llvm/Support/CommandLine.h"
32 #include "llvm/Target/TargetOptions.h"
33 #include "llvm/Transforms/Utils/Cloning.h"
34 #include "tensorflow/compiler/xla/debug_options_flags.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/literal.h"
37 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
38 #include "tensorflow/compiler/xla/service/dump.h"
39 #include "tensorflow/compiler/xla/service/name_uniquer.h"
40 #include "tensorflow/compiler/xla/shape_util.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/io/path.h"
45 #include "tensorflow/core/platform/byte_order.h"
46 #include "tensorflow/core/platform/env.h"
47 #include "tensorflow/core/platform/logging.h"
48 #include "tensorflow/core/platform/types.h"
49 
50 namespace xla {
51 namespace llvm_ir {
52 
53 namespace {
54 
55 // Note, this function is only useful in an insertion context; in a global
56 // (e.g. constants) context it will CHECK fail.
ModuleFromIRBuilder(llvm::IRBuilder<> * b)57 llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) {
58   auto block = CHECK_NOTNULL(b->GetInsertBlock());
59   auto fn = CHECK_NOTNULL(block->getParent());
60   auto module = CHECK_NOTNULL(fn->getParent());
61   return module;
62 }
63 
64 }  // namespace
65 
DropConstantInitializers(const llvm::Module & module)66 std::unique_ptr<llvm::Module> DropConstantInitializers(
67     const llvm::Module& module) {
68   std::unique_ptr<llvm::Module> cloned_module = CloneModule(module);
69   for (llvm::GlobalVariable& global_var : cloned_module->globals()) {
70     global_var.setInitializer(nullptr);
71     global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
72   }
73   return cloned_module;
74 }
75 
DumpModuleToString(const llvm::Module & module)76 string DumpModuleToString(const llvm::Module& module) {
77   std::string buffer_string;
78   llvm::raw_string_ostream ostream(buffer_string);
79   module.print(ostream, nullptr);
80   ostream.flush();
81   return buffer_string;
82 }
83 
EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id,absl::Span<llvm::Value * const> operands,absl::Span<llvm::Type * const> overloaded_types,llvm::IRBuilder<> * b,absl::string_view name)84 llvm::CallInst* EmitCallToIntrinsic(
85     llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
86     absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
87     absl::string_view name) {
88   llvm::Module* module = ModuleFromIRBuilder(b);
89   llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
90       module, intrinsic_id, AsArrayRef(overloaded_types));
91   return b->CreateCall(intrinsic, AsArrayRef(operands), name.data());
92 }
93 
EmitFloatMax(llvm::Value * lhs_value,llvm::Value * rhs_value,llvm::IRBuilder<> * b,bool enable_fast_min_max,absl::string_view name)94 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
95                           llvm::IRBuilder<>* b, bool enable_fast_min_max,
96                           absl::string_view name) {
97   if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
98     auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
99     return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
100   } else {
101     auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
102     auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
103     auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
104     return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
105   }
106 }
107 
EmitFloatMin(llvm::Value * lhs_value,llvm::Value * rhs_value,llvm::IRBuilder<> * b,bool enable_fast_min_max,absl::string_view name)108 llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
109                           llvm::IRBuilder<>* b, bool enable_fast_min_max,
110                           absl::string_view name) {
111   if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
112     auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
113     return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
114   } else {
115     auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
116     auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
117     auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
118     return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
119   }
120 }
121 
EmitBufferIndexingGEP(llvm::Value * array,llvm::Value * index,llvm::IRBuilder<> * b)122 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
123                                    llvm::IRBuilder<>* b) {
124   llvm::Type* array_type = array->getType();
125   CHECK(array_type->isPointerTy());
126   llvm::PointerType* array_type_as_pointer =
127       llvm::cast<llvm::PointerType>(array_type);
128   VLOG(2) << "EmitBufferIndexingGEP with type="
129           << llvm_ir::DumpToString(*array_type)
130           << " array=" << llvm_ir::DumpToString(*array)
131           << " index=" << llvm_ir::DumpToString(*index);
132 
133   return b->CreateInBoundsGEP(
134       array_type_as_pointer->getElementType(), array,
135       llvm::isa<llvm::GlobalVariable>(array)
136           ? llvm::ArrayRef<llvm::Value*>({b->getInt64(0), index})
137           : index);
138 }
139 
EmitBufferIndexingGEP(llvm::Value * array,int64 index,llvm::IRBuilder<> * b)140 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
141                                    llvm::IRBuilder<>* b) {
142   return EmitBufferIndexingGEP(array, b->getInt64(index), b);
143 }
144 
PrimitiveTypeToIrType(PrimitiveType element_type,llvm::Module * module)145 llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
146                                   llvm::Module* module) {
147   switch (element_type) {
148     case PRED:
149     case S8:
150     case U8:
151       return llvm::Type::getInt8Ty(module->getContext());
152     case S16:
153     case U16:
154     case BF16:
155       // For BF16 we just need some type that is 16 bits wide so that it will
156       // take up the right amount of space in memory. LLVM does not have a BF16
157       // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so
158       // we can't map it directly to an LLVM type. We will not map a BF16
159       // addition to an addition on this type (int16) - this is just the type
160       // used for storage.
161       return llvm::Type::getInt16Ty(module->getContext());
162     case F16:
163       return llvm::Type::getHalfTy(module->getContext());
164     case S32:
165     case U32:
166       return llvm::Type::getInt32Ty(module->getContext());
167     case S64:
168     case U64:
169       return llvm::Type::getInt64Ty(module->getContext());
170     case F32:
171       return llvm::Type::getFloatTy(module->getContext());
172     case F64:
173       return llvm::Type::getDoubleTy(module->getContext());
174     case C64: {
175       auto cplx_t =
176           llvm::StructType::getTypeByName(module->getContext(), "complex64");
177       if (cplx_t == nullptr) {
178         // C++ standard dictates the memory layout of std::complex is contiguous
179         // real followed by imaginary. C++11 section 26.4 [complex.numbers]:
180         // If z is an lvalue expression of type cv std::complex<T> then the
181         // expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
182         // reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of
183         // z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
184         // imaginary part of z.
185         return llvm::StructType::create(
186             {llvm::Type::getFloatTy(module->getContext()),
187              llvm::Type::getFloatTy(module->getContext())},
188             "complex64", /*isPacked=*/true);
189       }
190       return cplx_t;
191     }
192     case C128: {
193       auto cplx_t =
194           llvm::StructType::getTypeByName(module->getContext(), "complex128");
195       if (cplx_t == nullptr) {
196         return llvm::StructType::create(
197             {llvm::Type::getDoubleTy(module->getContext()),
198              llvm::Type::getDoubleTy(module->getContext())},
199             "complex128", /*isPacked=*/true);
200       }
201       return cplx_t;
202     }  // A Tuple contains an array of pointers. Use i8*.
203     case TUPLE:
204     // An Opaque is like a void*, use i8*.
205     case OPAQUE_TYPE:
206       return llvm::Type::getInt8PtrTy(module->getContext());
207     case TOKEN:
208       // Tokens do not have a physical representation, but the compiler needs
209       // some placeholder type, so use int8*.
210       return llvm::Type::getInt8PtrTy(module->getContext());
211     default:
212       LOG(FATAL) << "unsupported type " << element_type;
213   }
214 }
215 
GetSizeInBits(llvm::Type * type)216 int GetSizeInBits(llvm::Type* type) {
217   const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type);
218   if (struct_ty) {
219     CHECK(struct_ty->isPacked());
220     int bits = 0;
221     for (auto element_type : struct_ty->elements()) {
222       bits += GetSizeInBits(element_type);
223     }
224     return bits;
225   }
226   int bits = type->getPrimitiveSizeInBits();
227   CHECK_GT(bits, 0) << "type is not sized";
228   return bits;
229 }
230 
ShapeToIrType(const Shape & shape,llvm::Module * module)231 llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
232   llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
233   if (shape.IsTuple()) {
234     // A tuple buffer is an array of pointers.
235     result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
236   } else if (shape.IsArray()) {
237     for (int64 dimension : LayoutUtil::MinorToMajor(shape)) {
238       result_type =
239           llvm::ArrayType::get(result_type, shape.dimensions(dimension));
240     }
241   }
242   return result_type;
243 }
244 
EncodeSelfDescribingShapeConstant(const Shape & shape,int32 * shape_size,llvm::IRBuilder<> * b)245 StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
246                                                          int32* shape_size,
247                                                          llvm::IRBuilder<>* b) {
248   string encoded_shape = shape.SerializeAsString();
249   if (encoded_shape.size() > std::numeric_limits<int32>::max()) {
250     return InternalError("Encoded shape size exceeded int32 size limit.");
251   }
252   *shape_size = static_cast<int32>(encoded_shape.size());
253   return b->CreateGlobalStringPtr(encoded_shape);
254 }
255 
ConvertLiteralToIrConstant(const Literal & literal,llvm::Module * module)256 llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
257                                            llvm::Module* module) {
258   const char* data = static_cast<const char*>(literal.untyped_data());
259   CHECK_EQ(module->getDataLayout().isLittleEndian(),
260            tensorflow::port::kLittleEndian);
261   return llvm::ConstantDataArray::getString(
262       module->getContext(), llvm::StringRef(data, literal.size_bytes()),
263       /*AddNull=*/false);
264 }
265 
AllocateSharedMemoryTile(llvm::Module * module,llvm::Type * tile_type,absl::string_view name)266 llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
267                                                llvm::Type* tile_type,
268                                                absl::string_view name) {
269   // Both AMDGPU and NVPTX use the same address space for shared memory.
270   const int kGPUSharedMemoryAddrSpace = 3;
271   return new llvm::GlobalVariable(
272       *module, tile_type,
273       /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
274       llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr,
275       llvm::GlobalValue::NotThreadLocal, kGPUSharedMemoryAddrSpace);
276 }
277 
EmitAllocaAtFunctionEntry(llvm::Type * type,absl::string_view name,llvm::IRBuilder<> * b,int alignment)278 llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
279                                             absl::string_view name,
280                                             llvm::IRBuilder<>* b,
281                                             int alignment) {
282   return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
283 }
284 
EmitAllocaAtFunctionEntryWithCount(llvm::Type * type,llvm::Value * element_count,absl::string_view name,llvm::IRBuilder<> * b,int alignment)285 llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
286                                                      llvm::Value* element_count,
287                                                      absl::string_view name,
288                                                      llvm::IRBuilder<>* b,
289                                                      int alignment) {
290   llvm::IRBuilder<>::InsertPointGuard guard(*b);
291   llvm::Function* function = b->GetInsertBlock()->getParent();
292   b->SetInsertPoint(&function->getEntryBlock(),
293                     function->getEntryBlock().getFirstInsertionPt());
294   llvm::AllocaInst* alloca =
295       b->CreateAlloca(type, element_count, AsStringRef(name));
296   if (alignment != 0) {
297     alloca->setAlignment(llvm::Align(alignment));
298   }
299   return alloca;
300 }
301 
CreateBasicBlock(llvm::BasicBlock * insert_before,absl::string_view name,llvm::IRBuilder<> * b)302 llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
303                                    absl::string_view name,
304                                    llvm::IRBuilder<>* b) {
305   return llvm::BasicBlock::Create(
306       /*Context=*/b->getContext(),
307       /*Name=*/AsStringRef(name),
308       /*Parent=*/b->GetInsertBlock()->getParent(),
309       /*InsertBefore*/ insert_before);
310 }
311 
EmitIfThenElse(llvm::Value * condition,absl::string_view name,llvm::IRBuilder<> * b,bool emit_else)312 LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
313                           llvm::IRBuilder<>* b, bool emit_else) {
314   llvm_ir::LlvmIfData if_data;
315   if_data.if_block = b->GetInsertBlock();
316   if_data.true_block =
317       CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
318   if_data.false_block =
319       emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
320                 : nullptr;
321 
322   // Add a terminator to the if block, if necessary.
323   if (if_data.if_block->getTerminator() == nullptr) {
324     b->SetInsertPoint(if_data.if_block);
325     if_data.after_block =
326         CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
327     b->CreateBr(if_data.after_block);
328   } else {
329     if_data.after_block = if_data.if_block->splitBasicBlock(
330         b->GetInsertPoint(), absl::StrCat(name, "-after"));
331   }
332 
333   // Our basic block should now end with an unconditional branch.  Remove it;
334   // we're going to replace it with a conditional branch.
335   if_data.if_block->getTerminator()->eraseFromParent();
336 
337   b->SetInsertPoint(if_data.if_block);
338   b->CreateCondBr(condition, if_data.true_block,
339                   emit_else ? if_data.false_block : if_data.after_block);
340 
341   b->SetInsertPoint(if_data.true_block);
342   b->CreateBr(if_data.after_block);
343 
344   if (emit_else) {
345     b->SetInsertPoint(if_data.false_block);
346     b->CreateBr(if_data.after_block);
347   }
348 
349   b->SetInsertPoint(if_data.after_block,
350                     if_data.after_block->getFirstInsertionPt());
351 
352   return if_data;
353 }
354 
EmitComparison(llvm::CmpInst::Predicate predicate,llvm::Value * lhs_value,llvm::Value * rhs_value,llvm::IRBuilder<> * b,absl::string_view name)355 llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
356                             llvm::Value* lhs_value, llvm::Value* rhs_value,
357                             llvm::IRBuilder<>* b, absl::string_view name) {
358   llvm::Value* comparison_result;
359   if (lhs_value->getType()->isIntegerTy()) {
360     comparison_result =
361         b->CreateICmp(predicate, lhs_value, rhs_value, name.data());
362   } else {
363     comparison_result =
364         b->CreateFCmp(predicate, lhs_value, rhs_value, name.data());
365   }
366   // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
367   // arrays. So we extend it to i8 so that it's addressable.
368   return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType(
369                                               PRED, ModuleFromIRBuilder(b)));
370 }
371 
372 // Internal helper that is called from emitted code to log an int64 value with a
373 // tag.
LogS64(const char * tag,int64 value)374 static void LogS64(const char* tag, int64 value) {
375   LOG(INFO) << tag << " (int64): " << value;
376 }
377 
EmitLogging(const char * tag,llvm::Value * value,llvm::IRBuilder<> * b)378 void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) {
379   llvm::FunctionType* log_function_type = llvm::FunctionType::get(
380       b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false);
381   b->CreateCall(log_function_type,
382                 b->CreateIntToPtr(b->getInt64(absl::bit_cast<int64>(&LogS64)),
383                                   log_function_type->getPointerTo()),
384                 {b->getInt64(absl::bit_cast<int64>(tag)), value});
385 }
386 
SetAlignmentMetadataForLoad(llvm::LoadInst * load,uint64_t alignment)387 void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
388   llvm::LLVMContext& context = load->getContext();
389   llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
390   llvm::Constant* alignment_constant =
391       llvm::ConstantInt::get(int64_ty, alignment);
392   llvm::MDBuilder metadata_builder(context);
393   auto* alignment_metadata =
394       metadata_builder.createConstant(alignment_constant);
395   load->setMetadata(llvm::LLVMContext::MD_align,
396                     llvm::MDNode::get(context, alignment_metadata));
397 }
398 
SetDereferenceableMetadataForLoad(llvm::LoadInst * load,uint64_t dereferenceable_bytes)399 void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
400                                        uint64_t dereferenceable_bytes) {
401   llvm::LLVMContext& context = load->getContext();
402   llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
403   llvm::Constant* dereferenceable_bytes_constant =
404       llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
405   llvm::MDBuilder metadata_builder(context);
406   auto* dereferenceable_bytes_metadata =
407       metadata_builder.createConstant(dereferenceable_bytes_constant);
408   load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
409                     llvm::MDNode::get(context, dereferenceable_bytes_metadata));
410 }
411 
AddRangeMetadata(int64 lower,int64 upper,llvm::Instruction * inst)412 llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
413                                     llvm::Instruction* inst) {
414   llvm::LLVMContext& context = inst->getParent()->getContext();
415   llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
416   inst->setMetadata(
417       llvm::LLVMContext::MD_range,
418       llvm::MDNode::get(
419           context,
420           {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
421            llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
422   return inst;
423 }
424 
IrName(absl::string_view a)425 string IrName(absl::string_view a) {
426   std::string s(a);
427   s.erase(std::remove(s.begin(), s.end(), '%'), s.end());
428   return s;
429 }
430 
IrName(absl::string_view a,absl::string_view b)431 string IrName(absl::string_view a, absl::string_view b) {
432   if (!a.empty() && !b.empty()) {
433     return IrName(absl::StrCat(a, ".", b));
434   }
435   return IrName(absl::StrCat(a, b));
436 }
437 
IrName(const HloInstruction * a,absl::string_view b)438 string IrName(const HloInstruction* a, absl::string_view b) {
439   return IrName(a->name(), b);
440 }
441 
// Rewrites `function_name` into a name every backend accepts.
//
// NVPTX has the strictest rules, so we target a slightly stricter version of
// them: names match /[a-zA-Z_$][a-zA-Z0-9_$]*/, and the names "_" and "$" are
// not allowed.
//  * every other character is replaced by '_';
//  * an empty name becomes "__unnamed";
//  * a leading digit gets a '_' prefix;
//  * "_" and "$" get a trailing '_'.
string SanitizeFunctionName(string function_name) {
  for (char& c : function_name) {
    const bool is_letter = ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z');
    const bool is_digit = '0' <= c && c <= '9';
    if (!is_letter && !is_digit && c != '_' && c != '$') {
      c = '_';
    }
  }

  if (function_name.empty()) {
    return "__unnamed";
  }

  const char first = function_name.front();
  if ('0' <= first && first <= '9') {
    function_name.insert(function_name.begin(), '_');
  }

  if (function_name == "_" || function_name == "$") {
    function_name.push_back('_');
  }
  return function_name;
}
478 
SetToFirstInsertPoint(llvm::BasicBlock * blk,llvm::IRBuilder<> * builder)479 void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
480   builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
481 }
482 
SetToLastInsertPoint(llvm::BasicBlock * blk,llvm::IRBuilder<> * builder)483 void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
484   if (llvm::Instruction* terminator = blk->getTerminator()) {
485     builder->SetInsertPoint(terminator);
486   } else {
487     builder->SetInsertPoint(blk);
488   }
489 }
490 
CreateRor(llvm::Value * rotand,llvm::Value * rotor,llvm::IRBuilder<> * builder)491 llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
492                        llvm::IRBuilder<>* builder) {
493   auto size = rotand->getType()->getPrimitiveSizeInBits();
494   auto size_value = builder->getIntN(size, size);
495   auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
496   return builder->CreateOr(
497       builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
498       builder->CreateLShr(rotand, mod(rotor)));
499 }
500 
ByteSizeOf(const Shape & shape,const llvm::DataLayout & data_layout)501 int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
502   unsigned pointer_size = data_layout.getPointerSize();
503   return ShapeUtil::ByteSizeOf(shape, pointer_size);
504 }
505 
GetCpuFastMathFlags(const HloModuleConfig & module_config)506 llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) {
507   llvm::FastMathFlags flags;
508   const auto& options = module_config.debug_options();
509   if (!options.xla_cpu_enable_fast_math()) {
510     return flags;
511   }
512   // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal,
513   // AllowContract, and ApproxFunc.
514   flags.setFast();
515   flags.setNoNaNs(!options.xla_cpu_fast_math_honor_nans());
516   flags.setNoInfs(!options.xla_cpu_fast_math_honor_infs());
517   flags.setAllowReciprocal(!options.xla_cpu_fast_math_honor_division());
518   flags.setApproxFunc(!options.xla_cpu_fast_math_honor_functions());
519   return flags;
520 }
521 
MergeMetadata(llvm::LLVMContext * context,const std::map<int,llvm::MDNode * > & a,const std::map<int,llvm::MDNode * > & b)522 std::map<int, llvm::MDNode*> MergeMetadata(
523     llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
524     const std::map<int, llvm::MDNode*>& b) {
525   // We should extend this as needed to deal with other kinds of metadata like
526   // !dereferenceable and !range.
527 
528   std::map<int, llvm::MDNode*> result;
529   for (auto kind_md_pair : a) {
530     if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
531       llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
532       llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
533       for (const auto& scope_a : kind_md_pair.second->operands()) {
534         scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
535         union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
536       }
537       auto it = b.find(kind_md_pair.first);
538       if (it != b.end()) {
539         for (const auto& scope_b : it->second->operands()) {
540           if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
541             union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
542           }
543         }
544       }
545       result[llvm::LLVMContext::MD_alias_scope] =
546           llvm::MDNode::get(*context, union_of_scopes);
547     } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
548       llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
549       llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
550       for (const auto& scope_a : kind_md_pair.second->operands()) {
551         scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
552       }
553       auto it = b.find(kind_md_pair.first);
554       if (it != b.end()) {
555         for (const auto& scope_b : it->second->operands()) {
556           if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
557             intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
558           }
559         }
560       }
561       if (!intersection_of_scopes.empty()) {
562         result[llvm::LLVMContext::MD_noalias] =
563             llvm::MDNode::get(*context, intersection_of_scopes);
564       }
565     }
566   }
567   return result;
568 }
569 
CreateAndWriteStringToFile(const string & directory_name,const string & file_name,const string & text)570 static Status CreateAndWriteStringToFile(const string& directory_name,
571                                          const string& file_name,
572                                          const string& text) {
573   std::unique_ptr<tensorflow::WritableFile> f;
574   TF_RETURN_IF_ERROR(
575       tensorflow::Env::Default()->RecursivelyCreateDir(directory_name));
576   TF_RETURN_IF_ERROR(
577       tensorflow::Env::Default()->NewWritableFile(file_name, &f));
578   TF_RETURN_IF_ERROR(f->Append(text));
579   TF_RETURN_IF_ERROR(f->Close());
580   return Status::OK();
581 }
582 
DumpIrIfEnabled(const HloModule & hlo_module,const llvm::Module & llvm_module,bool optimized,absl::string_view filename_suffix)583 void DumpIrIfEnabled(const HloModule& hlo_module,
584                      const llvm::Module& llvm_module, bool optimized,
585                      absl::string_view filename_suffix) {
586   const auto& debug_opts = hlo_module.config().debug_options();
587   if (!DumpingEnabledForHloModule(hlo_module)) {
588     return;
589   }
590   // We can end up compiling different modules with the same name when using
591   // XlaJitCompiledCpuFunction::Compile.  Avoid overwriting IR files previously
592   // dumped from the same process in such cases.
593   string suffix = absl::StrCat("ir-", optimized ? "with" : "no", "-opt");
594   DumpToFileInDirOrStdout(
595       hlo_module, "",
596       absl::StrCat(suffix, filename_suffix.empty() ? "" : ".", filename_suffix,
597                    ".ll"),
598       DumpModuleToString(llvm_module));
599 
600   // For some models the embedded constants can be huge, so also dump the module
601   // with the constants stripped to get IR that is easier to manipulate.  Skip
602   // this if we're dumping to stdout; there's no point in duplicating everything
603   // when writing to the terminal.
604   if (!DumpingToStdout(debug_opts)) {
605     DumpToFileInDir(hlo_module, "", absl::StrCat(suffix, "-noconst.ll"),
606                     DumpModuleToString(*DropConstantInitializers(llvm_module)));
607   }
608 }
609 
CreateCpuFunction(llvm::FunctionType * function_type,llvm::GlobalValue::LinkageTypes linkage,const HloModuleConfig & module_config,absl::string_view name,llvm::Module * module)610 llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
611                                   llvm::GlobalValue::LinkageTypes linkage,
612                                   const HloModuleConfig& module_config,
613                                   absl::string_view name,
614                                   llvm::Module* module) {
615   llvm::Function* function =
616       llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
617   function->setCallingConv(llvm::CallingConv::C);
618   function->addFnAttr("no-frame-pointer-elim", "false");
619 
620   // Generate unwind information so that GDB can crawl through the stack frames
621   // created by the JIT compiled code.
622   function->setHasUWTable();
623 
624   // Tensorflow always flushes denormals to zero, let LLVM know that flushing
625   // denormals is safe. This allows vectorization using ARM's neon instruction
626   // set.
627   function->addFnAttr("denormal-fp-math", "preserve-sign");
628 
629   // Add the optimize attribute to the function if optimizing for size. This
630   // controls internal behavior of some optimization passes (e.g. loop
631   // unrolling).
632   if (cpu::options::OptimizeForSizeRequested(module_config)) {
633     function->addFnAttr(llvm::Attribute::OptimizeForSize);
634   }
635 
636   return function;
637 }
638 
InitializeLLVMCommandLineOptions(const HloModuleConfig & config)639 void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
640   auto options = config.debug_options().xla_backend_extra_options();
641   if (!options.empty()) {
642     std::vector<string> fake_argv_storage;
643     fake_argv_storage.push_back("");
644     for (const auto& it : options) {
645       // Skip options the XLA backend itself consumes.
646       if (!absl::StartsWith(it.first, "xla_")) {
647         if (it.second.empty()) {
648           fake_argv_storage.push_back(it.first);
649         } else {
650           fake_argv_storage.push_back(it.first + "=" + it.second);
651         }
652       }
653     }
654 
655     VLOG(2) << "Passing argv to LLVM:";
656     std::vector<const char*> fake_argv;
657     for (const auto& s : fake_argv_storage) {
658       fake_argv.push_back(s.c_str());
659       VLOG(2) << s;
660     }
661     llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
662   }
663 }
664 
UMulLowHigh32(llvm::IRBuilder<> * b,llvm::Value * src0,llvm::Value * src1)665 std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
666                                                     llvm::Value* src0,
667                                                     llvm::Value* src1) {
668   CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32);
669   CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32);
670   llvm::Type* int64_ty = b->getInt64Ty();
671   src0 = b->CreateZExt(src0, int64_ty);
672   src1 = b->CreateZExt(src1, int64_ty);
673   return SplitInt64ToInt32s(b, b->CreateMul(src0, src1));
674 }
675 
SplitInt64ToInt32s(llvm::IRBuilder<> * b,llvm::Value * value_64bits)676 std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
677     llvm::IRBuilder<>* b, llvm::Value* value_64bits) {
678   CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64);
679   llvm::Type* int32_ty = b->getInt32Ty();
680   llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty);
681   llvm::Value* high_32bits =
682       b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty);
683   return std::make_pair(low_32bits, high_32bits);
684 }
685 
GetGlobalMemoryAddressSpace(const llvm::Module & module)686 unsigned GetGlobalMemoryAddressSpace(const llvm::Module& module) {
687   const unsigned kAMDGPUGlobalMemoryAddrSpace = 1;
688   llvm::Triple target_triple = llvm::Triple(module.getTargetTriple());
689   if (target_triple.getArch() == llvm::Triple::amdgcn) {
690     // AMDGPU uses 1 for global memory address space.
691     return kAMDGPUGlobalMemoryAddrSpace;
692   }
693   return 0;
694 }
695 
GetOrCreateVariableForRngState(llvm::Module * module,llvm::IRBuilder<> * b)696 llvm::GlobalVariable* GetOrCreateVariableForRngState(llvm::Module* module,
697                                                      llvm::IRBuilder<>* b) {
698   static const char* kRngStateVariableName = "rng_state";
699   llvm::GlobalVariable* state_ptr =
700       module->getNamedGlobal(kRngStateVariableName);
701   if (!state_ptr) {
702     unsigned global_address_space = GetGlobalMemoryAddressSpace(*module);
703     llvm::Type* state_type = b->getInt128Ty();
704     // Use a non-zero initial value as zero state can cause the result of the
705     // first random number generation not passing the chi-square test. The
706     // values used here are arbitrarily chosen, any non-zero values should be
707     // fine.
708     state_ptr = new llvm::GlobalVariable(
709         /*M=*/*module,
710         /*Ty=*/state_type,
711         /*isConstant=*/false,
712         /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
713         /*Initializer=*/llvm::ConstantInt::get(b->getInt128Ty(), 0x7012395ull),
714         /*Name=*/kRngStateVariableName,
715         /*InsertBefore=*/nullptr,
716         /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
717         /*AddressSpace=*/global_address_space,
718         /*isExternallyInitialized=*/false);
719   }
720   return state_ptr;
721 }
722 
RngGetAndUpdateState(uint64 delta,llvm::Module * module,llvm::IRBuilder<> * builder)723 llvm::Value* RngGetAndUpdateState(uint64 delta, llvm::Module* module,
724                                   llvm::IRBuilder<>* builder) {
725   llvm::GlobalVariable* state_ptr =
726       GetOrCreateVariableForRngState(module, builder);
727   llvm::LoadInst* state_value_old =
728       builder->CreateLoad(state_ptr, "load_state");
729   llvm::Value* state_value_new = builder->CreateAdd(
730       state_value_old,
731       llvm::ConstantInt::get(state_value_old->getType(), delta));
732   builder->CreateStore(state_value_new, state_ptr);
733   return state_value_old;
734 }
735 
736 }  // namespace llvm_ir
737 }  // namespace xla
738