1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
17 
18 #include "llvm/Linker/Linker.h"
19 #include "llvm/Transforms/IPO/Internalize.h"
20 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"  // from @llvm-project
21 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"  // from @llvm-project
22 #include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 #include "mlir/Pass/PassManager.h"  // from @llvm-project
26 #include "mlir/Target/LLVMIR.h"  // from @llvm-project
27 #include "mlir/Target/LLVMIR/Export.h"  // from @llvm-project
28 #include "mlir/Transforms/Passes.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
30 
31 namespace xla {
32 namespace cpu {
33 namespace {
34 
35 // Lower an MLIR module to an LLVM module.
MakeLLVMModule(mlir::OwningModuleRef module,llvm::LLVMContext * context)36 std::unique_ptr<llvm::Module> MakeLLVMModule(mlir::OwningModuleRef module,
37                                              llvm::LLVMContext *context) {
38   // When set, the LLVM backend will be allowed to reassociate floating-point
39   // reductions, which enables much more efficient "horizontal" SIMD
40   // implementations.
41   // TODO(kramerb): link this to the right option, command line flag, etc.
42   constexpr bool kReassociateFPReductions = true;
43 
44   mlir::PassManager manager(module->getContext(),
45                             mlir::OpPassManager::Nesting::Implicit);
46   manager.addPass(mlir::createConvertLinalgToLoopsPass());
47   manager.addPass(mlir::createLowerAffinePass());
48   manager.addPass(mlir::createLowerToCFGPass());
49   manager.addPass(mlir::createConvertVectorToLLVMPass(
50       mlir::LowerVectorToLLVMOptions().setReassociateFPReductions(
51           kReassociateFPReductions)));
52   CHECK(succeeded(manager.run(*module)));
53   return mlir::translateModuleToLLVMIR(*module, *context);
54 }
55 
56 // Get arguments to pass a memref to an mlir function.
BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value * > * args,llvm::IRBuilder<> * b,const Shape & opShape,llvm::Value * op_val)57 void BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value *> *args,
58                         llvm::IRBuilder<> *b, const Shape &opShape,
59                         llvm::Value *op_val) {
60   llvm::Type *ty = op_val->getType();
61   while (auto aty = llvm::dyn_cast<llvm::ArrayType>(
62              llvm::cast<llvm::PointerType>(ty)->getElementType())) {
63     ty = aty->getElementType()->getPointerTo();
64   }
65   op_val = b->CreateBitCast(op_val, ty);
66 
67   args->push_back(op_val);          // Allocated pointer.
68   args->push_back(op_val);          // Aligned pointer.
69   args->push_back(b->getInt64(0));  // Offset.
70 
71   // Sizes.
72   for (int64 dim : opShape.dimensions()) {
73     args->push_back(b->getInt64(dim));
74   }
75 
76   int64_t accumulated_stride = 1;
77   llvm::SmallVector<int64_t, 4> strides(opShape.rank(), 1);
78   for (int64 dim : LayoutUtil::MinorToMajor(opShape)) {
79     strides[dim] = accumulated_stride;
80     accumulated_stride *= opShape.dimensions(dim);
81   }
82 
83   // Strides.
84   for (int64 stride : strides) {
85     args->push_back(b->getInt64(stride));
86   }
87 }
88 }  // namespace
89 
// Builds an MLIR function named `func_name` (its body is filled in by the
// `emitter` callback), lowers and links it into the LLVM module that `b` is
// currently inserting into, and emits a call to it. The MLIR function takes
// the result buffer as its first memref argument followed by one memref per
// operand, and returns nothing.
Status EmitMlirFuncAndCall(
    mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape,
    llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr,
    llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name,
    llvm::function_ref<void(mlir::OpBuilder *, mlir::FuncOp)> emitter) {
  llvm::Module *llvm_module = b->GetInsertBlock()->getParent()->getParent();
  mlir::Builder mlir_builder(context);

  // Get memref types for the inputs and output.
  TF_ASSIGN_OR_RETURN(mlir::Type ret_memref, ConvertTensorShapeToMemRefType(
                                                 result_shape, mlir_builder));
  // The result memref is passed as the first argument; the function writes
  // its output through that buffer instead of returning a value.
  std::vector<mlir::Type> operand_types = {ret_memref};
  for (int i = 0; i != operand_shapes.size(); ++i) {
    TF_ASSIGN_OR_RETURN(
        mlir::Type op_memref,
        ConvertTensorShapeToMemRefType(operand_shapes[i], mlir_builder));
    operand_types.push_back(op_memref);
  }

  // Create the function and call the emission callback.
  mlir::Location loc = mlir::UnknownLoc::get(context);
  auto function = mlir::FuncOp::create(
      loc, func_name, mlir::FunctionType::get(context, operand_types, {}));
  function.addEntryBlock();
  mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc);
  mlir_module->push_back(function);
  mlir::OpBuilder op_builder(&function.getBody());
  emitter(&op_builder, function);

  // Now link it all into the main LLVM module.
  auto mlir_llvm_module =
      MakeLLVMModule(std::move(mlir_module), &b->getContext());
  mlir_llvm_module->setDataLayout(llvm_module->getDataLayout());
  // Internalize every linked-in symbol that the destination module does not
  // reference by name, so MLIR-generated helpers neither leak out of nor
  // collide with the main module.
  llvm::Linker::linkModules(
      *llvm_module, std::move(mlir_llvm_module), llvm::Linker::None,
      [](llvm::Module &M, const llvm::StringSet<> &GVS) {
        llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) {
          return !GV.hasName() || (GVS.count(GV.getName()) == 0);
        });
      });

  // And leave behind a call to the function generated by MLIR.
  llvm::Function *func = llvm_module->getFunction(func_name);
  llvm::SmallVector<llvm::Value *, 4> op_vals;
  // Expand each buffer into the flat memref-descriptor argument list
  // (allocated ptr, aligned ptr, offset, sizes, strides).
  BuildViewForBuffer(&op_vals, b, result_shape, result_ptr);
  for (int i = 0; i != operand_shapes.size(); ++i) {
    BuildViewForBuffer(&op_vals, b, operand_shapes[i], operand_ptrs[i]);
  }
  b->CreateCall(func, op_vals);

  return Status::OK();
}
142 
143 }  // namespace cpu
144 }  // namespace xla
145