1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/base/casts.h"
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_cat.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/GlobalValue.h"
27 #include "llvm/IR/GlobalVariable.h"
28 #include "llvm/IR/MDBuilder.h"
29 #include "llvm/IR/Operator.h"
30 #include "llvm/Target/TargetOptions.h"
31 #include "llvm/Transforms/Utils/Cloning.h"
32 #include "tensorflow/compiler/xla/layout_util.h"
33 #include "tensorflow/compiler/xla/literal.h"
34 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
35 #include "tensorflow/compiler/xla/service/dump.h"
36 #include "tensorflow/compiler/xla/service/name_uniquer.h"
37 #include "tensorflow/compiler/xla/shape_util.h"
38 #include "tensorflow/compiler/xla/types.h"
39 #include "tensorflow/compiler/xla/util.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/lib/io/path.h"
42 #include "tensorflow/core/platform/byte_order.h"
43 #include "tensorflow/core/platform/env.h"
44 #include "tensorflow/core/platform/logging.h"
45 #include "tensorflow/core/platform/types.h"
46 
47 namespace xla {
48 namespace llvm_ir {
49 
50 namespace {
51 
52 // Note, this function is only useful in an insertion context; in a global
53 // (e.g. constants) context it will CHECK fail.
ModuleFromIRBuilder(llvm::IRBuilder<> * b)54 llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* b) {
55   auto block = CHECK_NOTNULL(b->GetInsertBlock());
56   auto fn = CHECK_NOTNULL(block->getParent());
57   auto module = CHECK_NOTNULL(fn->getParent());
58   return module;
59 }
60 
61 }  // namespace
62 
DropConstantInitializers(const llvm::Module & module)63 std::unique_ptr<llvm::Module> DropConstantInitializers(
64     const llvm::Module& module) {
65   std::unique_ptr<llvm::Module> cloned_module = CloneModule(module);
66   for (llvm::GlobalVariable& global_var : cloned_module->globals()) {
67     global_var.setInitializer(nullptr);
68     global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage);
69   }
70   return cloned_module;
71 }
72 
DumpModuleToString(const llvm::Module & module)73 string DumpModuleToString(const llvm::Module& module) {
74   std::string buffer_string;
75   llvm::raw_string_ostream ostream(buffer_string);
76   module.print(ostream, nullptr);
77   ostream.flush();
78   return buffer_string;
79 }
80 
EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id,absl::Span<llvm::Value * const> operands,absl::Span<llvm::Type * const> overloaded_types,llvm::IRBuilder<> * b)81 llvm::CallInst* EmitCallToIntrinsic(
82     llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
83     absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
84   llvm::Module* module = ModuleFromIRBuilder(b);
85   llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
86       module, intrinsic_id, AsArrayRef(overloaded_types));
87   return b->CreateCall(intrinsic, AsArrayRef(operands));
88 }
89 
EmitFloatMax(llvm::Value * lhs_value,llvm::Value * rhs_value,llvm::IRBuilder<> * b)90 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
91                           llvm::IRBuilder<>* b) {
92   if (b->getFastMathFlags().noNaNs()) {
93     auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
94     return b->CreateSelect(cmp, lhs_value, rhs_value);
95   } else {
96     auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
97     auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
98     auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
99     return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
100   }
101 }
102 
EmitFloatMin(llvm::Value * lhs_value,llvm::Value * rhs_value,llvm::IRBuilder<> * b)103 llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
104                           llvm::IRBuilder<>* b) {
105   if (b->getFastMathFlags().noNaNs()) {
106     auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
107     return b->CreateSelect(cmp, lhs_value, rhs_value);
108   } else {
109     auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
110     auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
111     auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
112     return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
113   }
114 }
115 
EmitBufferIndexingGEP(llvm::Value * array,llvm::Value * index,llvm::IRBuilder<> * b)116 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
117                                    llvm::IRBuilder<>* b) {
118   llvm::Type* array_type = array->getType();
119   CHECK(array_type->isPointerTy());
120   llvm::PointerType* array_type_as_pointer =
121       llvm::cast<llvm::PointerType>(array_type);
122   VLOG(2) << "EmitBufferIndexingGEP with type="
123           << llvm_ir::DumpToString(*array_type)
124           << " array=" << llvm_ir::DumpToString(*array)
125           << " index=" << llvm_ir::DumpToString(*index);
126 
127   return b->CreateInBoundsGEP(
128       array_type_as_pointer->getElementType(), array,
129       llvm::isa<llvm::GlobalVariable>(array)
130           ? llvm::ArrayRef<llvm::Value*>({b->getInt64(0), index})
131           : index);
132 }
133 
EmitBufferIndexingGEP(llvm::Value * array,int64 index,llvm::IRBuilder<> * b)134 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
135                                    llvm::IRBuilder<>* b) {
136   return EmitBufferIndexingGEP(array, b->getInt64(index), b);
137 }
138 
PrimitiveTypeToIrType(PrimitiveType element_type,llvm::Module * module)139 llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
140                                   llvm::Module* module) {
141   switch (element_type) {
142     case PRED:
143     case S8:
144     case U8:
145       return llvm::Type::getInt8Ty(module->getContext());
146     case S16:
147     case U16:
148     case BF16:
149       // For BF16 we just need some type that is 16 bits wide so that it will
150       // take up the right amount of space in memory. LLVM does not have a BF16
151       // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so
152       // we can't map it directly to an LLVM type. We will not map a BF16
153       // addition to an addition on this type (int16) - this is just the type
154       // used for storage.
155       return llvm::Type::getInt16Ty(module->getContext());
156     case F16:
157       return llvm::Type::getHalfTy(module->getContext());
158     case S32:
159     case U32:
160       return llvm::Type::getInt32Ty(module->getContext());
161     case S64:
162     case U64:
163       return llvm::Type::getInt64Ty(module->getContext());
164     case F32:
165       return llvm::Type::getFloatTy(module->getContext());
166     case F64:
167       return llvm::Type::getDoubleTy(module->getContext());
168     case C64: {
169       auto cplx_t = module->getTypeByName("complex64");
170       if (cplx_t == nullptr) {
171         // C++ standard dictates the memory layout of std::complex is contiguous
172         // real followed by imaginary. C++11 section 26.4 [complex.numbers]:
173         // If z is an lvalue expression of type cv std::complex<T> then the
174         // expression reinterpret_cast<cv T(&)[2]>(z) shall be well-formed,
175         // reinterpret_cast<cv T(&)[2]>(z)[0] shall designate the real part of
176         // z, and reinterpret_cast<cv T(&)[2]>(z)[1] shall designate the
177         // imaginary part of z.
178         return llvm::StructType::create(
179             {llvm::Type::getFloatTy(module->getContext()),
180              llvm::Type::getFloatTy(module->getContext())},
181             "complex64", /*isPacked=*/true);
182       }
183       return cplx_t;
184     }
185     case C128: {
186       auto cplx_t = module->getTypeByName("complex128");
187       if (cplx_t == nullptr) {
188         return llvm::StructType::create(
189             {llvm::Type::getDoubleTy(module->getContext()),
190              llvm::Type::getDoubleTy(module->getContext())},
191             "complex128", /*isPacked=*/true);
192       }
193       return cplx_t;
194     }  // A Tuple contains an array of pointers. Use i8*.
195     case TUPLE:
196     // An Opaque is like a void*, use i8*.
197     case OPAQUE:
198       return llvm::Type::getInt8PtrTy(module->getContext());
199     case TOKEN:
200       // Tokens do not have a physical representation, but the compiler needs
201       // some placeholder type, so use int8*.
202       return llvm::Type::getInt8PtrTy(module->getContext());
203     default:
204       LOG(FATAL) << "unsupported type " << element_type;
205   }
206 }
207 
GetSizeInBits(llvm::Type * type)208 int GetSizeInBits(llvm::Type* type) {
209   const llvm::StructType* struct_ty = llvm::dyn_cast<llvm::StructType>(type);
210   if (struct_ty) {
211     CHECK(struct_ty->isPacked());
212     int bits = 0;
213     for (auto element_type : struct_ty->elements()) {
214       bits += GetSizeInBits(element_type);
215     }
216     return bits;
217   }
218   int bits = type->getPrimitiveSizeInBits();
219   CHECK_GT(bits, 0) << "type is not sized";
220   return bits;
221 }
222 
ShapeToIrType(const Shape & shape,llvm::Module * module)223 llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) {
224   llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module);
225   if (shape.IsTuple()) {
226     // A tuple buffer is an array of pointers.
227     result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
228   } else if (shape.IsArray()) {
229     for (int64 dimension : LayoutUtil::MinorToMajor(shape)) {
230       result_type =
231           llvm::ArrayType::get(result_type, shape.dimensions(dimension));
232     }
233   }
234   return result_type;
235 }
236 
EncodeSelfDescribingShapeConstant(const Shape & shape,int32 * shape_size,llvm::IRBuilder<> * b)237 StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
238                                                          int32* shape_size,
239                                                          llvm::IRBuilder<>* b) {
240   string encoded_shape = shape.SerializeAsString();
241   if (encoded_shape.size() > std::numeric_limits<int32>::max()) {
242     return InternalError("Encoded shape size exceeded int32 size limit.");
243   }
244   *shape_size = static_cast<int32>(encoded_shape.size());
245   return b->CreateGlobalStringPtr(encoded_shape);
246 }
247 
DecodeSelfDescribingShapeConstant(const void * shape_ptr,int32 size_bytes)248 StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
249                                                   int32 size_bytes) {
250   ShapeProto shape_proto;
251   TF_RET_CHECK(shape_proto.ParseFromArray(shape_ptr, size_bytes));
252   Shape shape(shape_proto);
253   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
254   return std::move(shape);
255 }
256 
ConvertLiteralToIrConstant(const Literal & literal,llvm::Module * module)257 llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
258                                            llvm::Module* module) {
259   const char* data = static_cast<const char*>(literal.untyped_data());
260   CHECK_EQ(module->getDataLayout().isLittleEndian(),
261            tensorflow::port::kLittleEndian);
262   return llvm::ConstantDataArray::getString(
263       module->getContext(), llvm::StringRef(data, literal.size_bytes()),
264       /*AddNull=*/false);
265 }
266 
AllocateSharedMemoryTile(llvm::Module * module,llvm::Type * tile_type,absl::string_view name)267 llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
268                                                llvm::Type* tile_type,
269                                                absl::string_view name) {
270   const int kNVPTXSharedMemoryAddrSpace = 3;
271   return new llvm::GlobalVariable(
272       *module, tile_type,
273       /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
274       llvm::UndefValue::get(tile_type), AsStringRef(name), nullptr,
275       llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace);
276 }
277 
EmitAllocaAtFunctionEntry(llvm::Type * type,absl::string_view name,llvm::IRBuilder<> * b,int alignment)278 llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
279                                             absl::string_view name,
280                                             llvm::IRBuilder<>* b,
281                                             int alignment) {
282   return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment);
283 }
284 
EmitAllocaAtFunctionEntryWithCount(llvm::Type * type,llvm::Value * element_count,absl::string_view name,llvm::IRBuilder<> * b,int alignment)285 llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
286                                                      llvm::Value* element_count,
287                                                      absl::string_view name,
288                                                      llvm::IRBuilder<>* b,
289                                                      int alignment) {
290   llvm::IRBuilder<>::InsertPointGuard guard(*b);
291   llvm::Function* function = b->GetInsertBlock()->getParent();
292   b->SetInsertPoint(&function->getEntryBlock(),
293                     function->getEntryBlock().getFirstInsertionPt());
294   llvm::AllocaInst* alloca =
295       b->CreateAlloca(type, element_count, AsStringRef(name));
296   if (alignment != 0) {
297     alloca->setAlignment(alignment);
298   }
299   return alloca;
300 }
301 
CreateBasicBlock(llvm::BasicBlock * insert_before,absl::string_view name,llvm::IRBuilder<> * b)302 llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
303                                    absl::string_view name,
304                                    llvm::IRBuilder<>* b) {
305   return llvm::BasicBlock::Create(
306       /*Context=*/b->getContext(),
307       /*Name=*/AsStringRef(name),
308       /*Parent=*/b->GetInsertBlock()->getParent(),
309       /*InsertBefore*/ insert_before);
310 }
311 
// Emits an if-then(-else) diamond at the builder's current insertion point.
//
// Creates a "<name>-true" block and, if `emit_else` is set, a "<name>-false"
// block. Each ends in an unconditional branch to a "<name>-after" block. The
// current block is made to end in a conditional branch on `condition` to the
// true block, and otherwise to the false block, or directly to the after
// block when `emit_else` is false.
//
// If the current block has no terminator yet, a fresh after block is created
// and a branch to it is appended. Otherwise the current block is split at the
// insertion point, and the split-off tail becomes the after block.
//
// On return the builder is positioned at the first insertion point of the
// after block. The returned LlvmIfData records the blocks involved;
// false_block is nullptr when `emit_else` is false.
LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
                          llvm::IRBuilder<>* b, bool emit_else) {
  llvm_ir::LlvmIfData if_data;
  if_data.if_block = b->GetInsertBlock();
  // The new blocks are appended to the end of the current function
  // (insert_before == nullptr).
  if_data.true_block =
      CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b);
  if_data.false_block =
      emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b)
                : nullptr;

  // Add a terminator to the if block, if necessary.
  if (if_data.if_block->getTerminator() == nullptr) {
    // No terminator yet: branch from the end of the if block to a new, empty
    // after block.
    b->SetInsertPoint(if_data.if_block);
    if_data.after_block =
        CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b);
    b->CreateBr(if_data.after_block);
  } else {
    // Already terminated: split at the insertion point. Everything from the
    // insertion point on (including the old terminator) moves to the after
    // block, and the if block gets a branch to it.
    if_data.after_block = if_data.if_block->splitBasicBlock(
        b->GetInsertPoint(), absl::StrCat(name, "-after"));
  }

  // Our basic block should now end with an unconditional branch.  Remove it;
  // we're going to replace it with a conditional branch.
  if_data.if_block->getTerminator()->eraseFromParent();

  b->SetInsertPoint(if_data.if_block);
  b->CreateCondBr(condition, if_data.true_block,
                  emit_else ? if_data.false_block : if_data.after_block);

  // The true (and false) blocks start out containing only a branch to the
  // after block. Callers insert their bodies before that branch.
  b->SetInsertPoint(if_data.true_block);
  b->CreateBr(if_data.after_block);

  if (emit_else) {
    b->SetInsertPoint(if_data.false_block);
    b->CreateBr(if_data.after_block);
  }

  // Leave the builder where code following the if/else should go.
  b->SetInsertPoint(if_data.after_block,
                    if_data.after_block->getFirstInsertionPt());

  return if_data;
}
354 
EmitComparison(llvm::CmpInst::Predicate predicate,llvm::Value * lhs_value,llvm::Value * rhs_value,llvm::IRBuilder<> * b)355 llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
356                             llvm::Value* lhs_value, llvm::Value* rhs_value,
357                             llvm::IRBuilder<>* b) {
358   llvm::Value* comparison_result;
359   if (lhs_value->getType()->isIntegerTy()) {
360     comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value);
361   } else {
362     comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value);
363   }
364   // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
365   // arrays. So we extend it to i8 so that it's addressable.
366   return b->CreateZExt(comparison_result, llvm_ir::PrimitiveTypeToIrType(
367                                               PRED, ModuleFromIRBuilder(b)));
368 }
369 
370 // Internal helper that is called from emitted code to log an int64 value with a
371 // tag.
LogS64(const char * tag,int64 value)372 static void LogS64(const char* tag, int64 value) {
373   LOG(INFO) << tag << " (int64): " << value;
374 }
375 
EmitLogging(const char * tag,llvm::Value * value,llvm::IRBuilder<> * b)376 void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) {
377   llvm::FunctionType* log_function_type = llvm::FunctionType::get(
378       b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false);
379   b->CreateCall(log_function_type,
380                 b->CreateIntToPtr(b->getInt64(absl::bit_cast<int64>(&LogS64)),
381                                   log_function_type->getPointerTo()),
382                 {b->getInt64(absl::bit_cast<int64>(tag)), value});
383 }
384 
SetAlignmentMetadataForLoad(llvm::LoadInst * load,uint64_t alignment)385 void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
386   llvm::LLVMContext& context = load->getContext();
387   llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
388   llvm::Constant* alignment_constant =
389       llvm::ConstantInt::get(int64_ty, alignment);
390   llvm::MDBuilder metadata_builder(context);
391   auto* alignment_metadata =
392       metadata_builder.createConstant(alignment_constant);
393   load->setMetadata(llvm::LLVMContext::MD_align,
394                     llvm::MDNode::get(context, alignment_metadata));
395 }
396 
SetDereferenceableMetadataForLoad(llvm::LoadInst * load,uint64_t dereferenceable_bytes)397 void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
398                                        uint64_t dereferenceable_bytes) {
399   llvm::LLVMContext& context = load->getContext();
400   llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
401   llvm::Constant* dereferenceable_bytes_constant =
402       llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
403   llvm::MDBuilder metadata_builder(context);
404   auto* dereferenceable_bytes_metadata =
405       metadata_builder.createConstant(dereferenceable_bytes_constant);
406   load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
407                     llvm::MDNode::get(context, dereferenceable_bytes_metadata));
408 }
409 
AddRangeMetadata(int64 lower,int64 upper,llvm::Instruction * inst)410 llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
411                                     llvm::Instruction* inst) {
412   llvm::LLVMContext& context = inst->getParent()->getContext();
413   llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
414   inst->setMetadata(
415       llvm::LLVMContext::MD_range,
416       llvm::MDNode::get(
417           context,
418           {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
419            llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
420   return inst;
421 }
422 
// Returns `a` with every '%' character removed, so the result can be used as
// an LLVM IR value name.
string IrName(string a) {
  string cleaned;
  cleaned.reserve(a.size());
  for (char ch : a) {
    if (ch != '%') {
      cleaned.push_back(ch);
    }
  }
  return cleaned;
}
427 
IrName(absl::string_view a,absl::string_view b)428 string IrName(absl::string_view a, absl::string_view b) {
429   if (!a.empty() && !b.empty()) {
430     return IrName(absl::StrCat(a, ".", b));
431   }
432   return IrName(absl::StrCat(a, b));
433 }
434 
IrName(const HloInstruction * a,absl::string_view b)435 string IrName(const HloInstruction* a, absl::string_view b) {
436   return IrName(a->name(), b);
437 }
438 
// Rewrites `function_name` into a name every backend accepts.
//
// NVPTX has the strictest rules of the supported backends, so this sanitizes
// to a slightly stricter form of its requirements: names match
// /[a-zA-Z_$][a-zA-Z0-9_$]*/, and the names "_" and "$" are illegal.
//  - Every character outside [a-zA-Z0-9_$] becomes '_'.
//  - An empty name becomes "__unnamed".
//  - A leading digit gets a '_' prepended.
//  - "_" and "$" get a trailing '_' appended.
string SanitizeFunctionName(string function_name) {
  for (char& c : function_name) {
    const bool is_lower = 'a' <= c && c <= 'z';
    const bool is_upper = 'A' <= c && c <= 'Z';
    const bool is_digit = '0' <= c && c <= '9';
    if (!(is_lower || is_upper || is_digit || c == '_' || c == '$')) {
      c = '_';
    }
  }

  if (function_name.empty()) {
    return "__unnamed";
  }

  const char first = function_name.front();
  if ('0' <= first && first <= '9') {
    function_name.insert(function_name.begin(), '_');
  }

  if (function_name == "_" || function_name == "$") {
    function_name.push_back('_');
  }
  return function_name;
}
475 
SetToFirstInsertPoint(llvm::BasicBlock * blk,llvm::IRBuilder<> * builder)476 void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
477   builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
478 }
479 
SetToLastInsertPoint(llvm::BasicBlock * blk,llvm::IRBuilder<> * builder)480 void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
481   if (llvm::Instruction* terminator = blk->getTerminator()) {
482     builder->SetInsertPoint(terminator);
483   } else {
484     builder->SetInsertPoint(blk);
485   }
486 }
487 
CreateRor(llvm::Value * rotand,llvm::Value * rotor,llvm::IRBuilder<> * builder)488 llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
489                        llvm::IRBuilder<>* builder) {
490   auto size = rotand->getType()->getPrimitiveSizeInBits();
491   auto size_value = builder->getIntN(size, size);
492   auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
493   return builder->CreateOr(
494       builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
495       builder->CreateLShr(rotand, mod(rotor)));
496 }
497 
ByteSizeOf(const Shape & shape,const llvm::DataLayout & data_layout)498 int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
499   unsigned pointer_size = data_layout.getPointerSize();
500   return ShapeUtil::ByteSizeOf(shape, pointer_size);
501 }
502 
GetCpuFastMathFlags(const HloModuleConfig & module_config)503 llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) {
504   llvm::FastMathFlags flags;
505   if (!module_config.debug_options().xla_cpu_enable_fast_math()) {
506     return flags;
507   }
508 
509   // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, AllowReciprocal,
510   // AllowContract, and ApproxFunc.
511   flags.setFast();
512 
513   if (module_config.debug_options().xla_cpu_fast_math_honor_nans()) {
514     flags.setNoNaNs(false);
515   }
516 
517   if (module_config.debug_options().xla_cpu_fast_math_honor_infs()) {
518     flags.setNoInfs(false);
519   }
520 
521   return flags;
522 }
523 
// Merges two sets of instruction metadata, keyed by LLVM metadata kind ID,
// for an instruction that stands in for both originals.
//
// Only kinds present in `a` are visited, and only two kinds are handled.
// Every other kind, and any kind that appears only in `b`, is absent from the
// result:
//  - !alias.scope: the result is the union of the scope lists. `a`'s scopes
//    come first, then those of `b`'s scopes not already present.
//  - !noalias: the result is the intersection of the scope lists, in `b`'s
//    operand order. It is omitted when the intersection is empty or `b` has
//    no !noalias entry.
//
// NOTE(review): !alias.scope handling is asymmetric. If only `a` has it,
// `a`'s scopes are kept. If only `b` has it, the result has none.
// llvm::MDNode::getMostGenericAliasScope instead drops the metadata when
// either side lacks it. Confirm the asymmetry is intended.
std::map<int, llvm::MDNode*> MergeMetadata(
    llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
    const std::map<int, llvm::MDNode*>& b) {
  // We should extend this as needed to deal with other kinds of metadata like
  // !dereferenceable and !range.

  std::map<int, llvm::MDNode*> result;
  for (auto kind_md_pair : a) {
    if (kind_md_pair.first == llvm::LLVMContext::MD_alias_scope) {
      // Union: keep a's scopes in order, then append b's new ones. The set
      // only deduplicates; the vector fixes the operand order of the node.
      llvm::SmallVector<llvm::Metadata*, 8> union_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
        union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (!scope_set.count(llvm::cast<llvm::MDNode>(scope_b.get()))) {
            union_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b.get()));
          }
        }
      }
      result[llvm::LLVMContext::MD_alias_scope] =
          llvm::MDNode::get(*context, union_of_scopes);
    } else if (kind_md_pair.first == llvm::LLVMContext::MD_noalias) {
      // Intersection: keep only the scopes that appear in both lists, in
      // b's order.
      llvm::SmallVector<llvm::Metadata*, 8> intersection_of_scopes;
      llvm::SmallPtrSet<llvm::Metadata*, 8> scope_set;
      for (const auto& scope_a : kind_md_pair.second->operands()) {
        scope_set.insert(llvm::cast<llvm::MDNode>(scope_a.get()));
      }
      auto it = b.find(kind_md_pair.first);
      if (it != b.end()) {
        for (const auto& scope_b : it->second->operands()) {
          if (scope_set.count(llvm::cast<llvm::MDNode>(scope_b))) {
            intersection_of_scopes.push_back(llvm::cast<llvm::MDNode>(scope_b));
          }
        }
      }
      // An empty !noalias list would carry no information; omit it.
      if (!intersection_of_scopes.empty()) {
        result[llvm::LLVMContext::MD_noalias] =
            llvm::MDNode::get(*context, intersection_of_scopes);
      }
    }
  }
  return result;
}
571 
CreateAndWriteStringToFile(const string & directory_name,const string & file_name,const string & text)572 static Status CreateAndWriteStringToFile(const string& directory_name,
573                                          const string& file_name,
574                                          const string& text) {
575   std::unique_ptr<tensorflow::WritableFile> f;
576   TF_RETURN_IF_ERROR(
577       tensorflow::Env::Default()->RecursivelyCreateDir(directory_name));
578   TF_RETURN_IF_ERROR(
579       tensorflow::Env::Default()->NewWritableFile(file_name, &f));
580   TF_RETURN_IF_ERROR(f->Append(text));
581   TF_RETURN_IF_ERROR(f->Close());
582   return Status::OK();
583 }
584 
DumpIrIfEnabled(const HloModule & hlo_module,const llvm::Module & llvm_module,bool optimized)585 void DumpIrIfEnabled(const HloModule& hlo_module,
586                      const llvm::Module& llvm_module, bool optimized) {
587   const auto& debug_opts = hlo_module.config().debug_options();
588   if (!DumpingEnabledForHloModule(hlo_module)) {
589     return;
590   }
591   // We can end up compiling different modules with the same name when using
592   // XlaJitCompiledCpuFunction::Compile.  Avoid overwriting IR files previously
593   // dumped from the same process in such cases.
594   string suffix = absl::StrCat("ir-", optimized ? "with" : "no", "-opt");
595   DumpToFileInDirOrStdout(hlo_module, absl::StrCat(suffix, ".ll"),
596                           DumpModuleToString(llvm_module));
597 
598   // For some models the embedded constants can be huge, so also dump the module
599   // with the constants stripped to get IR that is easier to manipulate.  Skip
600   // this if we're dumping to stdout; there's no point in duplicating everything
601   // when writing to the terminal.
602   if (!DumpingToStdout(debug_opts)) {
603     DumpToFileInDir(hlo_module, absl::StrCat(suffix, "-noconst.ll"),
604                     DumpModuleToString(*DropConstantInitializers(llvm_module)));
605   }
606 }
607 
CreateCpuFunction(llvm::FunctionType * function_type,llvm::GlobalValue::LinkageTypes linkage,const HloModuleConfig & module_config,absl::string_view name,llvm::Module * module)608 llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
609                                   llvm::GlobalValue::LinkageTypes linkage,
610                                   const HloModuleConfig& module_config,
611                                   absl::string_view name,
612                                   llvm::Module* module) {
613   llvm::Function* function =
614       llvm::Function::Create(function_type, linkage, AsStringRef(name), module);
615   function->setCallingConv(llvm::CallingConv::C);
616   function->addFnAttr("no-frame-pointer-elim", "false");
617 
618   // Generate unwind information so that GDB can crawl through the stack frames
619   // created by the JIT compiled code.
620   function->setHasUWTable();
621 
622   if (module_config.debug_options().xla_cpu_enable_fast_math()) {
623     function->addFnAttr("unsafe-fp-math", "true");
624     function->addFnAttr("no-signed-zeros-fp-math", "true");
625 
626     if (!module_config.debug_options().xla_cpu_fast_math_honor_nans()) {
627       function->addFnAttr("no-nans-fp-math", "true");
628     }
629 
630     if (!module_config.debug_options().xla_cpu_fast_math_honor_infs()) {
631       function->addFnAttr("no-infs-fp-math", "true");
632     }
633   }
634 
635   // Add the optize attribute to the function if optimizing for size. This
636   // controls internal behavior of some optimization passes (e.g. loop
637   // unrolling).
638   if (cpu::options::OptimizeForSizeRequested(module_config)) {
639     function->addFnAttr(llvm::Attribute::OptimizeForSize);
640   }
641 
642   return function;
643 }
644 
InitializeLLVMCommandLineOptions(const HloModuleConfig & config)645 void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
646   auto options = config.debug_options().xla_backend_extra_options();
647   if (!options.empty()) {
648     std::vector<string> fake_argv_storage;
649     fake_argv_storage.push_back("");
650     for (const auto& it : options) {
651       // Skip options the XLA backend itself consumes.
652       if (!absl::StartsWith(it.first, "xla_")) {
653         if (it.second.empty()) {
654           fake_argv_storage.push_back(it.first);
655         } else {
656           fake_argv_storage.push_back(it.first + "=" + it.second);
657         }
658       }
659     }
660 
661     VLOG(2) << "Passing argv to LLVM:";
662     std::vector<const char*> fake_argv;
663     for (const auto& s : fake_argv_storage) {
664       fake_argv.push_back(s.c_str());
665       VLOG(2) << s;
666     }
667     llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
668   }
669 }
670 
UMulLowHigh32(llvm::IRBuilder<> * b,llvm::Value * src0,llvm::Value * src1)671 std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
672                                                     llvm::Value* src0,
673                                                     llvm::Value* src1) {
674   CHECK_EQ(src0->getType()->getPrimitiveSizeInBits(), 32);
675   CHECK_EQ(src1->getType()->getPrimitiveSizeInBits(), 32);
676   llvm::Type* int64_ty = b->getInt64Ty();
677   src0 = b->CreateZExt(src0, int64_ty);
678   src1 = b->CreateZExt(src1, int64_ty);
679   return SplitInt64ToInt32s(b, b->CreateMul(src0, src1));
680 }
681 
SplitInt64ToInt32s(llvm::IRBuilder<> * b,llvm::Value * value_64bits)682 std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
683     llvm::IRBuilder<>* b, llvm::Value* value_64bits) {
684   CHECK_EQ(value_64bits->getType()->getPrimitiveSizeInBits(), 64);
685   llvm::Type* int32_ty = b->getInt32Ty();
686   llvm::Value* low_32bits = b->CreateTrunc(value_64bits, int32_ty);
687   llvm::Value* high_32bits =
688       b->CreateTrunc(b->CreateLShr(value_64bits, 32), int32_ty);
689   return std::make_pair(low_32bits, high_32bits);
690 }
691 
GetOrCreateVariableForPhiloxRngState(llvm::Module * module,llvm::IRBuilder<> * b)692 llvm::GlobalVariable* GetOrCreateVariableForPhiloxRngState(
693     llvm::Module* module, llvm::IRBuilder<>* b) {
694   static const char* kPhiloxRngStateVariableName = "philox_rng_state";
695   llvm::GlobalVariable* state_ptr =
696       module->getNamedGlobal(kPhiloxRngStateVariableName);
697   if (!state_ptr) {
698     state_ptr = new llvm::GlobalVariable(
699         /*M=*/*module,
700         /*Ty=*/b->getInt64Ty(),
701         /*isConstant=*/false,
702         /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
703         /*Initializer=*/b->getInt64(0),
704         /*Name=*/kPhiloxRngStateVariableName);
705   }
706   return state_ptr;
707 }
708 
IncrementVariableForPhiloxRngState(int64 value,llvm::Module * module,llvm::IRBuilder<> * builder)709 void IncrementVariableForPhiloxRngState(int64 value, llvm::Module* module,
710                                         llvm::IRBuilder<>* builder) {
711   llvm::GlobalVariable* state_ptr =
712       GetOrCreateVariableForPhiloxRngState(module, builder);
713   llvm::Value* state_value_old = builder->CreateLoad(state_ptr, "load_state");
714   // If the 64-bit value overflows, we use the wraparound value. This should
715   // be fine in practice as we only add one to the value each time when a RNG is
716   // executed.
717   llvm::Value* state_value_new = builder->CreateAdd(
718       state_value_old, builder->getInt64(value), "inc_state");
719   builder->CreateStore(state_value_new, state_ptr);
720 }
721 
722 }  // namespace llvm_ir
723 }  // namespace xla
724