1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file contains the analysis and transformation to rewrite kernel
// functions such that information about alignment, aliasing and zero offsets
// stemming from the tf_framework uses is propagated.
19 
20 #include <cstdint>
21 #include <memory>
22 
23 #include "llvm/ADT/Bitfields.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
26 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
27 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
29 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
30 #include "mlir/Support/LLVM.h"  // from @llvm-project
31 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
32 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
33 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
34 
35 namespace mlir {
36 namespace kernel_gen {
37 namespace transforms {
38 namespace {
39 
40 #define GEN_PASS_CLASSES
41 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
42 
43 struct PropagateTfAbiKnowledgeToKernelsPass
44     : public PropagateTfAbiKnowledgeToKernelsBase<
45           PropagateTfAbiKnowledgeToKernelsPass> {
runOnFunctionmlir::kernel_gen::transforms::__anon247e9bbd0111::PropagateTfAbiKnowledgeToKernelsPass46   void runOnFunction() override {
47     FuncOp function = getFunction();
48     llvm::SmallVector<Value, 4> worklist;
49     // We currently only handle entry functions and do not propagate across
50     // functions.
51     if (function->getAttrOfType<mlir::UnitAttr>(
52             tf_framework::TFFrameworkDialect::kTFEntryAttrName)) {
53       // For all operands of this function, we know they are aligned. Also, by
54       // construction of kernel generator, we know that there is no offset and
55       // the inner stride is one.
56       // TODO(herhut): Insert asserts in debug mode to check this.
57       for (auto argument : function.getArguments()) {
58         if (argument.getType().isa<BaseMemRefType>()) {
59           worklist.push_back(argument);
60           allocated_by_tf_runtime.insert(argument);
61           offset_is_zero.insert(argument);
62           inner_stride_is_constant.insert({argument, 1});
63         }
64       }
65     }
66 
67     // For locally allocated values, we know they are aligned and have offset
68     // zero. Further, they also do not alias with other memrefs, except in
69     // benign ways. This is by construction and ensured by the reuse analysis.
70     function.walk([&](tf_framework::TFAllocOp op) {
71       Value allocated = op.getResult();
72       worklist.push_back(allocated);
73       no_alias.insert(allocated);
74       allocated_by_tf_runtime.insert(allocated);
75       offset_is_zero.insert(allocated);
76       inner_stride_is_constant.insert({allocated, 1});
77     });
78 
79     // Next, take what we have and propagate it through known operations.
80     propagateThroughUses(worklist);
81 
82     // Now look at launches and make use of the knowledge we have.
83     function.walk([&](gpu::LaunchFuncOp launch) {
84       auto module = launch->getParentOfType<ModuleOp>();
85       auto kernel = module.lookupSymbol<LLVM::LLVMFuncOp>(launch.kernel());
86 
87       if (!kernel || kernel.isExternal()) return;
88 
89       // Count the position of kernel operands independently, as they do not
90       // coincide with laucnh operands as memref parameters get expanded when
91       // lowered to llvm.
92       int kernel_p = 0;
93       OpBuilder b = OpBuilder::atBlockBegin(&kernel.body().front());
94       llvm::SmallDenseMap<int64_t, Value> constants;
95       auto loc = kernel.getLoc();
96       for (auto operand : launch.operands()) {
97         auto memref = operand.getType().dyn_cast<MemRefType>();
98         if (!memref) {
99           // Scalar argument, advance kernel position by one.
100           kernel_p++;
101           continue;
102         }
103         if (allocated_by_tf_runtime.contains(operand)) {
104           // This was allocated by the tf runtime, so the two pointers in the
105           // descriptor coincide. Rewrite the kernel accordingly.
106           Value alloc_ptr = kernel.getArgument(kernel_p);
107           Value align_ptr = kernel.getArgument(kernel_p + 1);
108           alloc_ptr.replaceAllUsesWith(align_ptr);
109           kernel.setArgAttr(
110               kernel_p + 1, LLVM::LLVMDialect::getAlignAttrName(),
111               b.getIndexAttr(
112                   tf_framework::TFFrameworkDialect::kAllocationAlignment));
113         }
114         if (offset_is_zero.contains(operand)) {
115           Value offset = kernel.getArgument(kernel_p + 2);
116           Value &zero = constants[0];
117           if (!zero) {
118             zero = b.create<LLVM::ConstantOp>(loc, offset.getType(),
119                                               b.getIndexAttr(0));
120           }
121           offset.replaceAllUsesWith(zero);
122         }
123         auto const_stride = inner_stride_is_constant.find(operand);
124         if (const_stride != inner_stride_is_constant.end()) {
125           // The stride is the last argument belonging to this memref.
126           Value inner_stride =
127               kernel.getArgument(kernel_p + 2 + memref.getRank() * 2);
128           Value &stride_val = constants[const_stride->second];
129           if (!stride_val) {
130             stride_val = b.create<LLVM::ConstantOp>(
131                 loc, inner_stride.getType(),
132                 b.getIndexAttr(const_stride->second));
133           }
134           inner_stride.replaceAllUsesWith(stride_val);
135         }
136         if (no_alias.contains(operand)) {
137           // TODO(herhut): We also need to check whether any of the other args
138           //     are aliases. This is currently never the case by construction
139           //     but we could use the alias analysis from buffer placement here
140           //     to make sure.
141           // Add the no_alias attribute to the corresponding pointer.
142           kernel.setArgAttr(kernel_p + 1,
143                             LLVM::LLVMDialect::getNoAliasAttrName(),
144                             b.getBoolAttr(true));
145         }
146         // Advance base, aligned, offset, strides and sizes many arguments.
147         kernel_p += memref.getRank() * 2 + 3;
148       }
149     });
150   }
151 
152  private:
propagateThroughUsesmlir::kernel_gen::transforms::__anon247e9bbd0111::PropagateTfAbiKnowledgeToKernelsPass153   void propagateThroughUses(SmallVectorImpl<Value> &worklist) {
154     while (!worklist.empty()) {
155       Value candidate = worklist.pop_back_val();
156       for (auto user : candidate.getUsers()) {
157         if (isa<MemRefCastOp, MemRefReshapeOp>(user)) {
158           // Reshape and Cast propagate alignment, offset and innermost stride.
159           // TODO(herhut): This should be a trait.
160           Value result = user->getResult(0);
161           if (allocated_by_tf_runtime.contains(candidate)) {
162             allocated_by_tf_runtime.insert(result);
163           }
164           auto const_stride = inner_stride_is_constant.find(candidate);
165           if (const_stride != inner_stride_is_constant.end()) {
166             inner_stride_is_constant.insert({result, const_stride->second});
167           }
168           if (offset_is_zero.contains(candidate)) {
169             offset_is_zero.insert(result);
170           }
171           worklist.push_back(result);
172         }
173         if (auto cast = dyn_cast<MemRefReinterpretCastOp>(user)) {
174           // Check that we have offset 0.
175           Value result = cast.result();
176           if (!cast.isDynamicOffset(0) && cast.getStaticOffset(0) == 0) {
177             offset_is_zero.insert(result);
178           }
179           if (allocated_by_tf_runtime.contains(candidate)) {
180             allocated_by_tf_runtime.insert(result);
181           }
182           size_t last_stride = cast.getResultRank() - 1;
183           // TODO(herhut): Remove this once canonicalization handles this.
184           if (cast.isDynamicStride(last_stride)) {
185             auto dyn_stride = cast.getDynamicStride(last_stride)
186                                   .getDefiningOp<ConstantIndexOp>();
187             if (dyn_stride) {
188               inner_stride_is_constant.insert({result, dyn_stride.getValue()});
189             }
190           } else {
191             inner_stride_is_constant.insert(
192                 {result, cast.getStaticStride(last_stride)});
193           }
194           worklist.push_back(result);
195         }
196       }
197     }
198   }
199 
200   // Set of values that were allocated by the tf runtime and hence are aligned.
201   llvm::SmallPtrSet<Value, 8> allocated_by_tf_runtime;
202   // Set of values that are known to not have an offset of 0.
203   llvm::SmallPtrSet<Value, 8> offset_is_zero;
204   // Set of values that are known to have a constant stride.
205   llvm::SmallDenseMap<Value, int64_t, 8> inner_stride_is_constant;
206   // Set of values we know do not alias other values.
207   llvm::SmallPtrSet<Value, 8> no_alias;
208 };
209 
210 }  // namespace
211 
CreatePropagateTfAbiKnowledgeToKernels()212 std::unique_ptr<FunctionPass> CreatePropagateTfAbiKnowledgeToKernels() {
213   return std::make_unique<PropagateTfAbiKnowledgeToKernelsPass>();
214 }
215 
216 }  // namespace transforms
217 }  // namespace kernel_gen
218 }  // namespace mlir
219