1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <unordered_map>
18 #include <unordered_set>
19 #include <vector>
20
21 #include "llvm/ADT/EquivalenceClasses.h"
22 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
23 #include "mlir-hlo/utils/cycle_detector.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
25 #include "mlir/IR/MLIRContext.h" // TF:llvm-project
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/Pass/Pass.h" // TF:local_config_mlir
28 #include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
29
// This pass provides functionality similar to the fusion pass in the XLA
// stack. However, unlike XLA, it targets the fully dynamic shape scenario.
// Currently, it implements the kLoop and kInput fusion templates.
// During conversion, it greedily searches for kLoop/kInput fusion
// patterns.
//
// Similar to XLA, this pass supports fusion patterns with multiple outputs,
// provided the shapes of all outputs are consistent. Some examples follow.
38 //
//           kLoop                          kInput
// +----+  +----+  +----+    +----+    +----+    +----+
// |elem|  |elem|  |elem|    |elem<----+elem+---->elem+----+
// +-+--+  +-+--+  +-+--+    +-+--+    +----+    +-+--+    |
//   |       |       |         |                   |       |
//   |               |         |                   |       |
// +-v--+    |     +-v--+   +--v---+            +--v---+   |
// |elem+<---+----<+elem|   |reduce|            |reduce|   |
// +-+--+          +-+--+   +--+---+            +--+---+   |
//   |               |         |                   |       |
//   |               |         |                   |       |
//   v               v         v                   v       v
51 //
// To this end, we also add a simple shape constraint analysis phase.
// The kLoop fusion template requires all the outputs of the fused
// pattern to have the same shape. However, we don't know the actual value
// of the shape at compile time in the dynamic shape world.
// Fortunately, we can still infer the relationship among different ops
// according to their shape constraint traits. Currently, we only consider
// shape equality propagation for elementwise ops (assuming that implicit
// shape broadcast is forbidden). The above process could be built on the
// shape dialect once it is ready.
61
62 namespace mlir {
63 namespace mhlo {
64 namespace {
65
66 using llvm::EquivalenceClasses;
67 using FusionPattern = std::vector<Operation*>;
68 using FusionPlan = std::vector<FusionPattern>;
69
70 // To support using EquivalenceClasses for Value
71 class ValueWrapper {
72 public:
ValueWrapper(Value value)73 explicit ValueWrapper(Value value) : value_(std::move(value)) {}
74
getValue() const75 Value getValue() const { return value_; }
76
operator ==(const ValueWrapper & rhs) const77 bool operator==(const ValueWrapper& rhs) const {
78 return getValue() == rhs.getValue();
79 }
80
81 private:
82 Value value_;
83 };
84
operator <(const ValueWrapper & lhs,const ValueWrapper & rhs)85 bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
86 auto lhs_value = lhs.getValue().getAsOpaquePointer();
87 auto rhs_value = rhs.getValue().getAsOpaquePointer();
88 return lhs_value < rhs_value;
89 }
90
IsFusible(Operation * op)91 bool IsFusible(Operation* op) {
92 if (matchPattern(op, m_Constant())) {
93 return true;
94 }
95 auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
96 return op_fusibility && (op_fusibility.isFusibleWithOperand() ||
97 op_fusibility.isFusibleWithConsumer());
98 }
99
GetInputsOfFusionPattern(const FusionPattern & pattern)100 SmallVector<Value, 4> GetInputsOfFusionPattern(const FusionPattern& pattern) {
101 SmallVector<Value, 4> inputs;
102 DenseSet<Value> input_set;
103 DenseSet<Operation*> op_set;
104 for (Operation* op : pattern) {
105 bool inserted = op_set.insert(op).second;
106 (void)inserted;
107 assert(inserted && "FusionPattern contains duplicate operations");
108 }
109
110 for (Operation* op : pattern) {
111 for (Value operand : op->getOperands()) {
112 Operation* operand_op = operand.getDefiningOp();
113 if (op_set.find(operand_op) != op_set.end()) {
114 // skip if defining op is in the pattern
115 continue;
116 }
117 if (input_set.insert(operand).second) {
118 inputs.push_back(operand);
119 }
120 }
121 }
122 return inputs;
123 }
124
GetOutputsOfFusionPattern(const FusionPattern & pattern)125 SmallVector<Value, 4> GetOutputsOfFusionPattern(const FusionPattern& pattern) {
126 SmallVector<Value, 4> outputs;
127 DenseSet<Operation*> op_set;
128 for (Operation* op : pattern) {
129 bool inserted = op_set.insert(op).second;
130 (void)inserted;
131 assert(inserted && "FusionPattern contains duplicate operations");
132 }
133
134 for (Operation* op : pattern) {
135 for (Value result : op->getResults()) {
136 bool has_external_user = llvm::any_of(
137 result.getUses(),
138 [&](OpOperand& use) { return !op_set.count(use.getOwner()); });
139 if (has_external_user) {
140 outputs.push_back(result);
141 }
142 }
143 }
144 return outputs;
145 }
146
// Returns a new pattern holding the ops of `lhs` followed by the ops of
// `rhs`. Neither input is modified.
FusionPattern MergeFusionPattern(const FusionPattern& lhs,
                                 const FusionPattern& rhs) {
  FusionPattern merged;
  merged.reserve(lhs.size() + rhs.size());
  merged.insert(merged.end(), lhs.begin(), lhs.end());
  merged.insert(merged.end(), rhs.begin(), rhs.end());
  return merged;
}
153
EffectiveSize(const FusionPattern & pattern)154 inline int EffectiveSize(const FusionPattern& pattern) {
155 return llvm::count_if(
156 pattern, [](Operation* op) { return !matchPattern(op, m_Constant()); });
157 }
158
159 // This is an simple shape constraint analysis, which is used to
160 // guide fusion decision (e.g. we only fuse shape-compatible ops).
161 //
162 // Currently, We only consider shape equality propagation based
163 // on the shape constrain traits of elementwise ops (assuming that
164 // implicit shape broadcast is forbidden).
165 class ShapeConstraintAnalysis {
166 public:
ShapeConstraintAnalysis(const SmallVectorImpl<Operation * > & op_list)167 explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
168 PropagateEquality(op_list);
169 }
170
171 // Returns true is `lhs` and `rhs` are supposed to have same shape.
HasSameShape(Value lhs,Value rhs)172 bool HasSameShape(Value lhs, Value rhs) {
173 return impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
174 }
175
176 private:
177 // shape equality propagation based on the shape constrains of
178 // elementwise ops.
PropagateEquality(const SmallVectorImpl<Operation * > & op_list)179 void PropagateEquality(const SmallVectorImpl<Operation*>& op_list) {
180 bool converged = true;
181 do {
182 converged = true;
183 auto update = [&](Value lhs, Value rhs) {
184 if (!impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
185 converged = false;
186 impl_.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
187 }
188 };
189 for (Operation* op : op_list) {
190 auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
191 if (!op_fusibility) continue;
192 int numInput = op->getNumOperands();
193 int numOutput = op->getNumResults();
194 // shape equality propagation between inputs.
195 for (int input1 = 0; input1 < numInput; ++input1)
196 for (int input2 = input1 + 1; input2 < numInput; ++input2)
197 if (op_fusibility.inferInputsShapeEquality(input1, input2))
198 update(op->getOperand(input1), op->getOperand(input2));
199
200 // shape equality propagation between outputs.
201 for (int output1 = 0; output1 < numOutput; ++output1)
202 for (int output2 = output1 + 1; output2 < numOutput; ++output2)
203 if (op_fusibility.inferOutputsShapeEquality(output1, output2))
204 update(op->getResult(output1), op->getResult(output2));
205
206 // shape equality propagation between input and output.
207 for (int input = 0; input < numInput; ++input)
208 for (int output = 0; output < numOutput; ++output)
209 if (op_fusibility.inferInputOutputShapeEquality(input, output))
210 update(op->getOperand(input), op->getResult(output));
211 }
212 } while (!converged);
213 }
214
215 // a UnionFind set
216 EquivalenceClasses<ValueWrapper> impl_;
217 };
218
// A fusion planner that can propose a fusion plan for a block of ops.
// The fusion plan consists of a group of fusion patterns.
//
// Currently all proposed patterns follow XLA kLoop/kInput-like fusion
// templates, adapted to the fully dynamic shape world.
//
// The kLoop fusion template satisfies:
//   - all ops in the fusion pattern are element-wise.
//   - all outputs of the fusion pattern have the same shape, and thus can
//     fit into the same parallel loop.
//
// The kInput fusion template satisfies:
//   - any op in the fusion pattern is either element-wise or a reduction.
//   - if an op is a reduction, its output cannot be consumed by other
//     ops in the same fusion pattern.
//   - all the effective shapes of the outputs of the fusion pattern are the
//     same.
//     - For an element-wise op, its effective shape is its output shape.
//     - For a reduction op, its effective shape is its operand shape.
237 class FusionPlanner {
238 public:
FusionPlanner(const SmallVectorImpl<Operation * > & op_list)239 explicit FusionPlanner(const SmallVectorImpl<Operation*>& op_list)
240 : op_list_(op_list),
241 shape_analysis_(op_list),
242 cycle_detector_(op_list.size()) {
243 BuildNodeMap();
244 }
245
246 // Returns a fusion plan if success, otherwise none.
Run()247 llvm::Optional<FusionPlan> Run() {
248 // Greedily search connected fusible pattern, and ops belonging to
249 // a same fusion pattern are grouped into a cluster.
250 RunEdgeContractionLoop();
251
252 // After doing edge contraction, each unique cluster having size
253 // more than one represents a potential fusion pattern.
254 // We collect all these clusters and construct a fusion plan.
255 //
256 // Note that the ops in a fusion pattern are in topological ordering.
257 FusionPlan plan;
258 DenseMap<int, int> pattern_ids;
259 for (Operation* op : op_list_) {
260 Cluster* cluster = GetClusterForNode(op);
261 int node_id = cluster->cycles_graph_node_id();
262 if (!IsFusible(op_list_[node_id]) ||
263 EffectiveSize(GetClusterForNode(op)->fused_pattern()) <= 1) {
264 continue;
265 }
266 if (!pattern_ids.count(node_id)) {
267 int pattern_id = pattern_ids.size();
268 pattern_ids[node_id] = pattern_id;
269 plan.emplace_back();
270 }
271 plan[pattern_ids[node_id]].push_back(op);
272 }
273 return plan;
274 }
275
276 // Returns the op_list this planner operates on.
op_list() const277 const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }
278
279 private:
280 // Represent a (partial) fused pattern
281 class Cluster {
282 public:
Cluster(int node_id,FusionPlanner * planner)283 Cluster(int node_id, FusionPlanner* planner) : node_id_(node_id) {
284 const SmallVectorImpl<Operation*>& op_list = planner->op_list();
285 pattern_.push_back(op_list[node_id]);
286 }
287
288 // Merges `other` into this cluster, and clears `other`.
Merge(Cluster * other)289 void Merge(Cluster* other) {
290 pattern_.insert(pattern_.end(), other->pattern_.begin(),
291 other->pattern_.end());
292 other->pattern_.clear();
293 }
294
295 // The number of nodes in this cluster.
cluster_size() const296 int cluster_size() const { return pattern_.size(); }
297
298 // The ID of the cluster as represented in `cycle_detector_`.
cycles_graph_node_id() const299 int cycles_graph_node_id() const { return node_id_; }
300
301 // Sets the ID of the cluster as represented in `cycle_detector_`.
set_cycles_graph_node_id(int cycles_graph_node_id)302 void set_cycles_graph_node_id(int cycles_graph_node_id) {
303 node_id_ = cycles_graph_node_id;
304 }
305
306 // Currently the fused pattern this cluster holds.
fused_pattern()307 const FusionPattern& fused_pattern() { return pattern_; }
308
309 private:
310 // ID of the representative node of this cluster.
311 int node_id_;
312
313 // the fused pattern this cluster holds.
314 FusionPattern pattern_;
315 };
316
317 private:
MakeCluster(int cycles_graph_node_id)318 Cluster* MakeCluster(int cycles_graph_node_id) {
319 cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
320 return cluster_storage_.back().get();
321 }
322
BuildNodeMap()323 void BuildNodeMap() {
324 int num_nodes = op_list_.size();
325 for (int node_id = 0; node_id < num_nodes; ++node_id) {
326 Operation* op = op_list_[node_id];
327 MakeCluster(node_id);
328 op_to_node_id_[op] = node_id;
329 leader_for_node_.insert(node_id);
330 for (Value operand : op->getOperands()) {
331 Operation* operand_op = operand.getDefiningOp();
332 if (operand_op == nullptr) {
333 // skip block argument
334 continue;
335 }
336 auto iter = op_to_node_id_.find(operand_op);
337 assert(iter != op_to_node_id_.end());
338 cycle_detector_.InsertEdge(iter->second, node_id);
339 }
340 }
341 }
342
343 // Returns the cluster contains this op.
GetClusterForNode(Operation * n)344 Cluster* GetClusterForNode(Operation* n) {
345 int id = op_to_node_id_.at(n);
346 id = leader_for_node_.getLeaderValue(id);
347 return cluster_storage_[id].get();
348 }
349
350 // Returns the cluster contains the op having `node_id`.
GetClusterForCyclesGraphNode(int node_id)351 Cluster* GetClusterForCyclesGraphNode(int node_id) {
352 return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
353 }
354
355 // Merges the clusters `cluster_from` and `cluster_to`.
MergeClusters(Cluster * cluster_from,Cluster * cluster_to)356 bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
357 int from = cluster_from->cycles_graph_node_id();
358 int to = cluster_to->cycles_graph_node_id();
359
360 auto optional_merged_node = cycle_detector_.ContractEdge(from, to);
361 if (!optional_merged_node.hasValue()) {
362 llvm::dbgs() << "Could not contract " << from << " -> " << to
363 << " because contracting the edge would create a cycle.";
364 return false;
365 }
366
367 // Merge the clusters.
368 cluster_from->Merge(cluster_to);
369 cluster_from->set_cycles_graph_node_id(*optional_merged_node);
370
371 // Merge the UnionFind Set.
372 leader_for_node_.unionSets(from, to);
373 return true;
374 }
375
376 template <typename FnTy>
ForEachEdgeInPostOrder(FnTy fn)377 bool ForEachEdgeInPostOrder(FnTy fn) {
378 bool changed = false;
379 for (int32_t node : cycle_detector_.AllNodesInPostOrder()) {
380 Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
381 // Make a copy of the set of successors because we may modify the graph in
382 // TryToContractEdge.
383 std::vector<int32_t> successors_copy =
384 cycle_detector_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
385
386 for (int to : successors_copy) {
387 Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
388 bool contracted_edge = fn(cluster_from, cluster_to);
389 changed |= contracted_edge;
390 }
391 }
392
393 return changed;
394 }
395
396 // returns the outputs if two cluster were merged
GetResultsOfFusedPattern(Cluster * from,Cluster * to)397 SmallVector<Value, 4> GetResultsOfFusedPattern(Cluster* from, Cluster* to) {
398 FusionPattern fused_pattern =
399 MergeFusionPattern(from->fused_pattern(), to->fused_pattern());
400 return GetOutputsOfFusionPattern(fused_pattern);
401 }
402
403 // This function check if fusing `from` with `to` is valid and if so perform
404 // the merge. The validity is based on the operations in the clusters and
405 // the compatibility of the shapes of the outputs of the would-be fused
406 // clusters.
407 // Returns true is the merge was performed.
TryToContractEdge(Cluster * from,Cluster * to)408 bool TryToContractEdge(Cluster* from, Cluster* to) {
409 int node_to = to->cycles_graph_node_id();
410 int node_from = from->cycles_graph_node_id();
411
412 // Both node_to and node_from should be fusible
413 if (!IsFusible(op_list_[node_to]) || !IsFusible(op_list_[node_from])) {
414 return false;
415 }
416
417 auto op_from_fusibility =
418 dyn_cast<InferFusibilityOpInterface>(op_list_[node_from]);
419 if (op_from_fusibility && !op_from_fusibility.isFusibleWithConsumer()) {
420 // This op cannot be fused with its consumers.
421 return false;
422 }
423
424 auto op_to_fusibility =
425 dyn_cast<InferFusibilityOpInterface>(op_list_[node_to]);
426 if (op_to_fusibility && !op_to_fusibility.isFusibleWithOperand()) {
427 // This op cannot be fused with its operands.
428 return false;
429 }
430
431 // Output shapes of a fusion pattern should be compatible as described in
432 // the document of this class.
433 SmallVector<Value, 4> results = GetResultsOfFusedPattern(from, to);
434 auto get_workload_shape = [](Value v) {
435 Operation* op = v.getDefiningOp();
436 // Block argument
437 if (!op) return v;
438 auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
439 // Const value
440 if (!op_fusibility) return v;
441 llvm::Optional<Value> workload =
442 op_fusibility.inferEffectiveWorkloadShape();
443 return workload.hasValue() ? *workload : v;
444 };
445
446 Value ref = get_workload_shape(results[0]);
447 if (!llvm::all_of(results, [&](Value result) {
448 Value val = get_workload_shape(result);
449 return shape_analysis_.HasSameShape(ref, val);
450 })) {
451 return false;
452 }
453
454 return MergeClusters(from, to);
455 }
456
457 // Greedily fuse connected node.
RunEdgeContractionLoop()458 bool RunEdgeContractionLoop() {
459 using std::placeholders::_1;
460 using std::placeholders::_2;
461 return ForEachEdgeInPostOrder(
462 std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2));
463 }
464
// The ops this planner works on. The index of an op in this list is its node
// id. Held by reference: the list must outlive the planner.
const SmallVectorImpl<Operation*>& op_list_;

// Shape equality checker
ShapeConstraintAnalysis shape_analysis_;

// op -> node_id
std::unordered_map<Operation*, int> op_to_node_id_;

// Used to make sure that fusion does not introduce a cycle.
GraphCycles cycle_detector_;
// Owns all clusters. A cluster is looked up at the index of the leader of
// its set in `leader_for_node_` (see GetClusterForCyclesGraphNode).
std::vector<std::unique_ptr<Cluster>> cluster_storage_;

// A union-find set over node ids. Each set represents a (partial) fused
// pattern and has a leader as its representative.
EquivalenceClasses<int32_t> leader_for_node_;
480 };
481
482 struct MhloFusionPass : public mlir::PassWrapper<MhloFusionPass, FunctionPass> {
runOnFunctionmlir::mhlo::__anon3c5f23730111::MhloFusionPass483 void runOnFunction() override {
484 FuncOp func = getFunction();
485 if (!IsTargetFunc(func)) {
486 return;
487 }
488
489 // process each block and do fusion within a block.
490 for (Block& block : func) {
491 SmallVector<Operation*, 4> op_list;
492 for (Operation& op : block) {
493 op_list.push_back(&op);
494 }
495
496 FusionPlanner planner(op_list);
497 llvm::Optional<FusionPlan> plan = planner.Run();
498 if (!plan) {
499 emitError(func.getLoc(), "can't find a fusion plan");
500 signalPassFailure();
501 return;
502 }
503 if (!ApplyFusionPlan(*plan)) {
504 emitError(func.getLoc(), "apply fusion plan failed");
505 signalPassFailure();
506 return;
507 }
508 }
509 }
510
IsTargetFuncmlir::mhlo::__anon3c5f23730111::MhloFusionPass511 bool IsTargetFunc(FuncOp func) {
512 int num_fusible_ops = 0;
513 bool is_target_func = false;
514 // We only process the function having enough candidates
515 func.walk([&](Operation* op) {
516 num_fusible_ops +=
517 static_cast<int>(dyn_cast<InferFusibilityOpInterface>(op) != nullptr);
518 is_target_func = (num_fusible_ops > 1);
519 // early stop
520 if (is_target_func) return WalkResult::interrupt();
521 return WalkResult::advance();
522 });
523 return is_target_func;
524 }
525
ApplyFusionPlanmlir::mhlo::__anon3c5f23730111::MhloFusionPass526 bool ApplyFusionPlan(const FusionPlan& plan) {
527 for (const FusionPattern& pattern : plan) {
528 OpBuilder b(pattern.back());
529
530 SmallVector<Location, 4> locations;
531 locations.reserve(pattern.size());
532 for (Operation* op : pattern) {
533 locations.push_back(op->getLoc());
534 }
535 Location fused_loc =
536 FusedLoc::get(locations, pattern.back()->getContext());
537
538 SmallVector<Value, 4> inputs = GetInputsOfFusionPattern(pattern);
539 SmallVector<Value, 4> outputs = GetOutputsOfFusionPattern(pattern);
540 SmallVector<Type, 4> output_types;
541 output_types.reserve(outputs.size());
542 for (Value v : outputs) {
543 output_types.push_back(v.getType());
544 }
545
546 FusionOp fusion =
547 b.create<mhlo::FusionOp>(fused_loc, output_types, inputs);
548 Region& region = fusion.fused_computation();
549 region.push_back(new Block);
550 Block& block = region.front();
551 for (Operation* op : pattern) {
552 op->moveBefore(&block, block.end());
553 }
554 b.setInsertionPoint(&block, block.end());
555 b.create<mhlo::ReturnOp>(fused_loc, outputs);
556
557 for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
558 Value output = std::get<0>(output_and_result);
559 Value fusion_result = std::get<1>(output_and_result);
560 for (OpOperand& use : llvm::make_early_inc_range(output.getUses())) {
561 if (use.getOwner()->getBlock() != &block) use.set(fusion_result);
562 }
563 }
564 }
565 return true;
566 }
567 };
568
569 } // namespace
570
createMhloFusionPass()571 std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass() {
572 return std::make_unique<MhloFusionPass>();
573 }
574
575 } // namespace mhlo
576 } // namespace mlir
577