1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This pass clusters operations according to the policy specified by the pass
17 // options. Clustered operations are placed in 'tf_device::ClusterOp'.
18
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/Block.h" // from @llvm-project
25 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
26 #include "mlir/IR/Builders.h" // from @llvm-project
27 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
28 #include "mlir/IR/MLIRContext.h" // from @llvm-project
29 #include "mlir/IR/Operation.h" // from @llvm-project
30 #include "mlir/Pass/Pass.h" // from @llvm-project
31 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
34 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
35 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
36 #include "tensorflow/core/platform/logging.h"
37
38 #define DEBUG_TYPE "cluster-ops-by-policy"
39
40 namespace mlir {
41 namespace TFDevice {
42
43 namespace {
44
// Attribute name holding the device an op is assigned to; it is compared
// between ops during matching and copied onto the created clusters.
constexpr const char kDeviceAttr[]{"device"};
46
47 // Pass definition.
48 struct ClusterOpsByPolicyPass
49 : public TF::ClusterOpsByPolicyPassBase<ClusterOpsByPolicyPass> {
50 void runOnFunction() override;
51 };
52
53 // Returns true if `op` starts a sequence of ops that match ops in `oplist`.
54 // The found ops are written into 'matched_ops' and added to 'is_matched' set.
55 // The next matched op must be the only user of the previous matched op's
56 // result. The matched ops do not have to be consecutive. For example,
57 // %1 = "tf.Add" %a, %b
58 // %2 = "tf.Neg" %a
59 // %3 = "tf.Sub" %c, %1 // the only use of %1
60 // matches "tf.Add, tf.Sub".
IsOplistMatch(Operation * op,ArrayRef<std::string> oplist,llvm::DenseSet<Operation * > & is_matched,llvm::SmallVectorImpl<Operation * > & matched_ops)61 bool IsOplistMatch(Operation *op, ArrayRef<std::string> oplist,
62 llvm::DenseSet<Operation *> &is_matched,
63 llvm::SmallVectorImpl<Operation *> &matched_ops) {
64 MLIRContext *ctx = op->getContext();
65
66 // Skip 'op' if it's already part of another matched sequence of ops.
67 if (is_matched.contains(op)) return false;
68
69 // Does this operation match first element in the oplist?
70 StringRef op_name = *oplist.begin();
71 if (op->getName().getIdentifier() != Identifier::get(op_name, ctx))
72 return false;
73
74 matched_ops.push_back(op);
75
76 // Check for match with the rest of oplist elements.
77 auto oplist_iter = oplist.begin() + 1;
78 auto oplist_end = oplist.end();
79 Block *block = op->getBlock();
80 auto device = op->getAttr(kDeviceAttr);
81 Operation *curr_op = op;
82
83 while (oplist_iter != oplist_end) {
84 // Find the next op to match.
85 if (!curr_op->hasOneUse()) return false;
86 curr_op = *curr_op->getUsers().begin();
87
88 // Skip 'op' if it's already part of another matched sequence of ops.
89 if (is_matched.contains(curr_op)) return false;
90
91 // Check that the op matches the next op in the oplist.
92 op_name = *oplist_iter;
93 if (curr_op->getName().getIdentifier() != Identifier::get(op_name, ctx))
94 return false;
95
96 // Don't cluster operations assigned to different devices.
97 if (curr_op->getAttr(kDeviceAttr) != device) return false;
98
99 // Don't cluster ops across blocks.
100 if (curr_op->getBlock() != block) return false;
101
102 // Check that op has no side effects. This guarantees that we will not
103 // reorder side-effecting ops during cluster formation.
104 if (!MemoryEffectOpInterface::hasNoEffect(curr_op)) return false;
105
106 ++oplist_iter;
107 matched_ops.push_back(curr_op);
108 }
109
110 is_matched.insert(matched_ops.begin(), matched_ops.end());
111
112 return true;
113 }
114
115 // Move matched operations into tf_device::ClusterOp.
ClusterMatchedOps(ArrayRef<Operation * > matched_ops)116 void ClusterMatchedOps(ArrayRef<Operation *> matched_ops) {
117 LLVM_DEBUG({
118 llvm::dbgs() << "Creating a cluster for matched ops:\n";
119 for (auto e : matched_ops) {
120 e->print(llvm::dbgs());
121 llvm::dbgs() << "\n";
122 }
123 llvm::dbgs() << "\n";
124 });
125
126 // Create tf_device::ClusterOp before the last matched operation.
127 Operation *lastOp = matched_ops.back();
128 OpBuilder builder(lastOp);
129 auto loc = lastOp->getLoc();
130 auto clusterOp =
131 builder.create<tf_device::ClusterOp>(loc, lastOp->getResultTypes());
132
133 // Create block in clusterOp's region and move 'matched_ops' into it.
134 auto block = builder.createBlock(&clusterOp.body());
135 auto block_end = block->end();
136 for (auto e : matched_ops) e->moveBefore(block, block_end);
137
138 // Replace uses of lastOp results with uses of tf_device.cluster op.h
139 lastOp->replaceAllUsesWith(clusterOp);
140
141 // Add 'tf_device::ReturnOp' at the end of the block.
142 builder.setInsertionPointToEnd(block);
143 builder.create<tf_device::ReturnOp>(loc, lastOp->getResults());
144
145 // Set device attribute
146 if (auto device = lastOp->getAttr(kDeviceAttr))
147 clusterOp->setAttr(kDeviceAttr, device);
148 }
149
150 // Define type to store list of operations.
151 typedef llvm::SmallVector<Operation *> OpList;
152
153 // Find operations that match 'oplist' and extract them into clusters.
runOnFunction()154 void ClusterOpsByPolicyPass::runOnFunction() {
155 if (oplist.empty()) return;
156
157 llvm::SmallVector<OpList> clusters;
158 llvm::DenseSet<Operation *> is_matched;
159
160 // Find matching op sequences within this function.
161 getFunction().walk([&](Operation *op) {
162 llvm::SmallVector<Operation *> matched_ops;
163
164 // Skip 'op' if it's already part of another matched sequence of ops.
165 if (is_matched.contains(op)) return;
166
167 // Try to match 'op' to the sequence of ops in 'oplist'.
168 if (!IsOplistMatch(op, oplist, is_matched, matched_ops)) return;
169
170 // We found a matching sequence of ops. Record it.
171 clusters.push_back(matched_ops);
172 });
173
174 // Create clusters.
175 for (const OpList &c : clusters) ClusterMatchedOps(c);
176 }
177
178 } // namespace
179
CreateClusterOpsByPolicyPass()180 std::unique_ptr<FunctionPass> CreateClusterOpsByPolicyPass() {
181 return std::make_unique<TFDevice::ClusterOpsByPolicyPass>();
182 }
183
184 } // namespace TFDevice
185 } // namespace mlir
186