1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This pass clusters operations according to the policy specified by the pass
17 // options. Clustered operations are placed in 'tf_device::ClusterOp'.
18 
#include <memory>
#include <string>
#include <utility>

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/core/platform/logging.h"
37 
38 #define DEBUG_TYPE "cluster-ops-by-policy"
39 
40 namespace mlir {
41 namespace TFDevice {
42 
43 namespace {
44 
// Name of the attribute that is compared between matched ops and copied from
// the last matched op onto the created cluster op.
constexpr char kDeviceAttr[] = "device";
46 
47 // Pass definition.
48 struct ClusterOpsByPolicyPass
49     : public TF::ClusterOpsByPolicyPassBase<ClusterOpsByPolicyPass> {
50   void runOnFunction() override;
51 };
52 
53 // Returns true if `op` starts a sequence of ops that match ops in `oplist`.
54 // The found ops are written into 'matched_ops' and added to 'is_matched' set.
55 // The next matched op must be the only user of the previous matched op's
56 // result. The matched ops do not have to be consecutive. For example,
57 //    %1 = "tf.Add" %a, %b
58 //    %2 = "tf.Neg" %a
59 //    %3 = "tf.Sub" %c, %1 // the only use of %1
60 // matches "tf.Add, tf.Sub".
IsOplistMatch(Operation * op,ArrayRef<std::string> oplist,llvm::DenseSet<Operation * > & is_matched,llvm::SmallVectorImpl<Operation * > & matched_ops)61 bool IsOplistMatch(Operation *op, ArrayRef<std::string> oplist,
62                    llvm::DenseSet<Operation *> &is_matched,
63                    llvm::SmallVectorImpl<Operation *> &matched_ops) {
64   MLIRContext *ctx = op->getContext();
65 
66   // Skip 'op' if it's already part of another matched sequence of ops.
67   if (is_matched.contains(op)) return false;
68 
69   // Does this operation match first element in the oplist?
70   StringRef op_name = *oplist.begin();
71   if (op->getName().getIdentifier() != Identifier::get(op_name, ctx))
72     return false;
73 
74   matched_ops.push_back(op);
75 
76   // Check for match with the rest of oplist elements.
77   auto oplist_iter = oplist.begin() + 1;
78   auto oplist_end = oplist.end();
79   Block *block = op->getBlock();
80   auto device = op->getAttr(kDeviceAttr);
81   Operation *curr_op = op;
82 
83   while (oplist_iter != oplist_end) {
84     // Find the next op to match.
85     if (!curr_op->hasOneUse()) return false;
86     curr_op = *curr_op->getUsers().begin();
87 
88     // Skip 'op' if it's already part of another matched sequence of ops.
89     if (is_matched.contains(curr_op)) return false;
90 
91     // Check that the op matches the next op in the oplist.
92     op_name = *oplist_iter;
93     if (curr_op->getName().getIdentifier() != Identifier::get(op_name, ctx))
94       return false;
95 
96     // Don't cluster operations assigned to different devices.
97     if (curr_op->getAttr(kDeviceAttr) != device) return false;
98 
99     // Don't cluster ops across blocks.
100     if (curr_op->getBlock() != block) return false;
101 
102     // Check that op has no side effects. This guarantees that we will not
103     // reorder side-effecting ops during cluster formation.
104     if (!MemoryEffectOpInterface::hasNoEffect(curr_op)) return false;
105 
106     ++oplist_iter;
107     matched_ops.push_back(curr_op);
108   }
109 
110   is_matched.insert(matched_ops.begin(), matched_ops.end());
111 
112   return true;
113 }
114 
115 // Move matched operations into tf_device::ClusterOp.
ClusterMatchedOps(ArrayRef<Operation * > matched_ops)116 void ClusterMatchedOps(ArrayRef<Operation *> matched_ops) {
117   LLVM_DEBUG({
118     llvm::dbgs() << "Creating a cluster for matched ops:\n";
119     for (auto e : matched_ops) {
120       e->print(llvm::dbgs());
121       llvm::dbgs() << "\n";
122     }
123     llvm::dbgs() << "\n";
124   });
125 
126   // Create tf_device::ClusterOp before the last matched operation.
127   Operation *lastOp = matched_ops.back();
128   OpBuilder builder(lastOp);
129   auto loc = lastOp->getLoc();
130   auto clusterOp =
131       builder.create<tf_device::ClusterOp>(loc, lastOp->getResultTypes());
132 
133   // Create block in clusterOp's region and move 'matched_ops' into it.
134   auto block = builder.createBlock(&clusterOp.body());
135   auto block_end = block->end();
136   for (auto e : matched_ops) e->moveBefore(block, block_end);
137 
138   // Replace uses of lastOp results with uses of tf_device.cluster op.h
139   lastOp->replaceAllUsesWith(clusterOp);
140 
141   // Add 'tf_device::ReturnOp' at the end of the block.
142   builder.setInsertionPointToEnd(block);
143   builder.create<tf_device::ReturnOp>(loc, lastOp->getResults());
144 
145   // Set device attribute
146   if (auto device = lastOp->getAttr(kDeviceAttr))
147     clusterOp->setAttr(kDeviceAttr, device);
148 }
149 
150 // Define type to store list of operations.
151 typedef llvm::SmallVector<Operation *> OpList;
152 
153 // Find operations that match 'oplist' and extract them into clusters.
runOnFunction()154 void ClusterOpsByPolicyPass::runOnFunction() {
155   if (oplist.empty()) return;
156 
157   llvm::SmallVector<OpList> clusters;
158   llvm::DenseSet<Operation *> is_matched;
159 
160   // Find matching op sequences within this function.
161   getFunction().walk([&](Operation *op) {
162     llvm::SmallVector<Operation *> matched_ops;
163 
164     // Skip 'op' if it's already part of another matched sequence of ops.
165     if (is_matched.contains(op)) return;
166 
167     // Try to match 'op' to the sequence of ops in 'oplist'.
168     if (!IsOplistMatch(op, oplist, is_matched, matched_ops)) return;
169 
170     // We found a matching sequence of ops. Record it.
171     clusters.push_back(matched_ops);
172   });
173 
174   // Create clusters.
175   for (const OpList &c : clusters) ClusterMatchedOps(c);
176 }
177 
178 }  // namespace
179 
CreateClusterOpsByPolicyPass()180 std::unique_ptr<FunctionPass> CreateClusterOpsByPolicyPass() {
181   return std::make_unique<TFDevice::ClusterOpsByPolicyPass>();
182 }
183 
184 }  // namespace TFDevice
185 }  // namespace mlir
186