/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"

#include "llvm/Linker/Linker.h"
#include "llvm/Transforms/IPO/Internalize.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"  // from @llvm-project
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Target/LLVMIR.h"  // from @llvm-project
#include "mlir/Target/LLVMIR/Export.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"

namespace xla {
namespace cpu {
namespace {

// Lower an MLIR module to an LLVM module.
std::unique_ptr<llvm::Module> MakeLLVMModule(mlir::OwningModuleRef module,
                                             llvm::LLVMContext *context) {
  // When set, the LLVM backend will be allowed to reassociate floating-point
  // reductions, which enables much more efficient "horizontal" SIMD
  // implementations.
  // TODO(kramerb): link this to the right option, command line flag, etc.
  constexpr bool kReassociateFPReductions = true;

  mlir::PassManager manager(module->getContext(),
                            mlir::OpPassManager::Nesting::Implicit);
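  // Progressively lower the module: Linalg ops to loops, affine ops to
  // standard/SCF, SCF to a plain CFG, and vector ops to the LLVM dialect,
  // so the module can be translated to LLVM IR below.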
  manager.addPass(mlir::createConvertLinalgToLoopsPass());
  manager.addPass(mlir::createLowerAffinePass());
  manager.addPass(mlir::createLowerToCFGPass());
  manager.addPass(mlir::createConvertVectorToLLVMPass(
      mlir::LowerVectorToLLVMOptions().setReassociateFPReductions(
          kReassociateFPReductions)));
  CHECK(succeeded(manager.run(*module)));
  return mlir::translateModuleToLLVMIR(*module, *context);
}

// Get arguments to pass a memref to an MLIR function.
void BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value *> *args,
                        llvm::IRBuilder<> *b, const Shape &opShape,
                        llvm::Value *op_val) {
  llvm::Type *ty = op_val->getType();
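  // Strip away nested array types so the buffer is passed as a plain pointer
  // to the element type (e.g. [16 x float]* becomes float*).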
  while (auto aty = llvm::dyn_cast<llvm::ArrayType>(
             llvm::cast<llvm::PointerType>(ty)->getElementType())) {
    ty = aty->getElementType()->getPointerTo();
  }
  op_val = b->CreateBitCast(op_val, ty);

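  // An expanded memref argument is passed as individual scalars: allocated
  // pointer, aligned pointer, offset, then one size and one stride per
  // dimension.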
  args->push_back(op_val);          // Allocated pointer.
  args->push_back(op_val);          // Aligned pointer.
  args->push_back(b->getInt64(0));  // Offset.

  // Sizes.
  for (int64 dim : opShape.dimensions()) {
    args->push_back(b->getInt64(dim));
  }

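  // Compute dense strides from the layout: walking dimensions from minor to
  // major, each dimension's stride is the product of the sizes of all more
  // minor dimensions.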
  int64_t accumulated_stride = 1;
  llvm::SmallVector<int64_t, 4> strides(opShape.rank(), 1);
  for (int64 dim : LayoutUtil::MinorToMajor(opShape)) {
    strides[dim] = accumulated_stride;
    accumulated_stride *= opShape.dimensions(dim);
  }

  // Strides.
  for (int64 stride : strides) {
    args->push_back(b->getInt64(stride));
  }
}
}  // namespace

Status EmitMlirFuncAndCall(
    mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape,
    llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr,
    llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name,
    llvm::function_ref<void(mlir::OpBuilder *, mlir::FuncOp)> emitter) {
  llvm::Module *llvm_module = b->GetInsertBlock()->getParent()->getParent();
  mlir::Builder mlir_builder(context);

  // Get memref types for the inputs and output.
  TF_ASSIGN_OR_RETURN(mlir::Type ret_memref, ConvertTensorShapeToMemRefType(
                                                 result_shape, mlir_builder));
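  // The generated function takes the result memref as its first argument,
  // followed by one memref per operand.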
  std::vector<mlir::Type> operand_types = {ret_memref};
  for (int i = 0; i != operand_shapes.size(); ++i) {
    TF_ASSIGN_OR_RETURN(
        mlir::Type op_memref,
        ConvertTensorShapeToMemRefType(operand_shapes[i], mlir_builder));
    operand_types.push_back(op_memref);
  }

  // Create the function and call the emission callback.
  mlir::Location loc = mlir::UnknownLoc::get(context);
  auto function = mlir::FuncOp::create(
      loc, func_name, mlir::FunctionType::get(context, operand_types, {}));
  function.addEntryBlock();
  mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc);
  mlir_module->push_back(function);
  mlir::OpBuilder op_builder(&function.getBody());
  emitter(&op_builder, function);

  // Now link it all into the main LLVM module.
  auto mlir_llvm_module =
      MakeLLVMModule(std::move(mlir_module), &b->getContext());
  mlir_llvm_module->setDataLayout(llvm_module->getDataLayout());
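  // Internalize the symbols pulled in from the MLIR-generated module so they
  // stay local to the main module and do not clash with existing symbols.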
  llvm::Linker::linkModules(
      *llvm_module, std::move(mlir_llvm_module), llvm::Linker::None,
      [](llvm::Module &M, const llvm::StringSet<> &GVS) {
        llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) {
          return !GV.hasName() || (GVS.count(GV.getName()) == 0);
        });
      });

  // And leave behind a call to the function generated by MLIR.
  llvm::Function *func = llvm_module->getFunction(func_name);
  llvm::SmallVector<llvm::Value *, 4> op_vals;
  BuildViewForBuffer(&op_vals, b, result_shape, result_ptr);
  for (int i = 0; i != operand_shapes.size(); ++i) {
    BuildViewForBuffer(&op_vals, b, operand_shapes[i], operand_ptrs[i]);
  }
  b->CreateCall(func, op_vals);

  return Status::OK();
}

}  // namespace cpu
}  // namespace xla