/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "llvm/ADT/EquivalenceClasses.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/utils/cycle_detector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"              // TF:llvm-project
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"               // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h"  // TF:llvm-project

// This pass provides functionality similar to the fusion pass in the XLA
// stack. However, unlike XLA, it targets the fully dynamic shape scenario.
// Currently, it implements the kLoop and kInput fusion templates. During
// conversion, it greedily searches for kLoop/kInput fusion patterns.
//
// Similar to XLA, this pass supports fusion patterns with multiple outputs
// as long as the shapes of all outputs are consistent. Below are some
// examples.
//
//        kLoop                          kInput
// +----+  +----+  +----+    +----+    +----+    +----+
// |elem|  |elem|  |elem|    |elem<----+elem+---->elem+----+
// +-+--+  +-+--+  +-+--+    +-+--+    +----+    +-+--+    |
//   |       |       |         |                   |       |
//   |               |         |                   |       |
// +-v--+    |     +-v--+   +--v---+            +--v---+   |
// |elem+<---+----<+elem|   |reduce|            |reduce|   |
// +-+--+          +-+--+   +--+---+            +--+---+   |
//   |               |         |                   |       |
//   |               |         |                   |       |
//   v               v         v                   v       v
//
// To this end, we also add a simple shape constraint analysis phase.
// The kLoop fusion template requires all outputs of the fused pattern
// to have the same shape. However, in the dynamic shape world we do not
// know the actual value of the shape at compile time. Fortunately, we
// can still infer the relationship among different ops according to
// their shape constraint traits. Currently, we only consider shape
// equality propagation for elementwise ops (assuming that implicit
// shape broadcast is forbidden). The above process could be built on the
// shape dialect once it is ready.
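//
// For illustration, a kLoop fusion over two elementwise ops might rewrite
// IR roughly as follows (a sketch only; the textual op form and the
// dynamic-shape types below are illustrative assumptions):
//
//   %0 = "mhlo.add"(%arg0, %arg1)
//       : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//   %1 = "mhlo.subtract"(%0, %arg2)
//       : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//
// into
//
//   %fused = "mhlo.fusion"(%arg0, %arg1, %arg2) ( {
//     %0 = "mhlo.add"(%arg0, %arg1)
//         : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//     %1 = "mhlo.subtract"(%0, %arg2)
//         : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//     "mhlo.return"(%1) : (tensor<?xf32>) -> ()
//   }) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>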

namespace mlir {
namespace mhlo {
namespace {

using llvm::EquivalenceClasses;
using FusionPattern = std::vector<Operation*>;
using FusionPlan = std::vector<FusionPattern>;

// To support using EquivalenceClasses for Value.
class ValueWrapper {
 public:
  explicit ValueWrapper(Value value) : value_(std::move(value)) {}

  Value getValue() const { return value_; }

  bool operator==(const ValueWrapper& rhs) const {
    return getValue() == rhs.getValue();
  }

 private:
  Value value_;
};

bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
  auto lhs_value = lhs.getValue().getAsOpaquePointer();
  auto rhs_value = rhs.getValue().getAsOpaquePointer();
  return lhs_value < rhs_value;
}

bool IsFusible(Operation* op) {
  if (matchPattern(op, m_Constant())) {
    return true;
  }
  auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
  return op_fusibility && (op_fusibility.isFusibleWithOperand() ||
                           op_fusibility.isFusibleWithConsumer());
}

SmallVector<Value, 4> GetInputsOfFusionPattern(const FusionPattern& pattern) {
  SmallVector<Value, 4> inputs;
  DenseSet<Value> input_set;
  DenseSet<Operation*> op_set;
  for (Operation* op : pattern) {
    bool inserted = op_set.insert(op).second;
    (void)inserted;
    assert(inserted && "FusionPattern contains duplicate operations");
  }

  for (Operation* op : pattern) {
    for (Value operand : op->getOperands()) {
      Operation* operand_op = operand.getDefiningOp();
      if (op_set.find(operand_op) != op_set.end()) {
        // Skip if the defining op is inside the pattern.
        continue;
      }
      if (input_set.insert(operand).second) {
        inputs.push_back(operand);
      }
    }
  }
  return inputs;
}

SmallVector<Value, 4> GetOutputsOfFusionPattern(const FusionPattern& pattern) {
  SmallVector<Value, 4> outputs;
  DenseSet<Operation*> op_set;
  for (Operation* op : pattern) {
    bool inserted = op_set.insert(op).second;
    (void)inserted;
    assert(inserted && "FusionPattern contains duplicate operations");
  }

  for (Operation* op : pattern) {
    for (Value result : op->getResults()) {
      bool has_external_user = llvm::any_of(
          result.getUses(),
          [&](OpOperand& use) { return !op_set.count(use.getOwner()); });
      if (has_external_user) {
        outputs.push_back(result);
      }
    }
  }
  return outputs;
}
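
// Example for the two helpers above: for a pattern
//   %0 = add(%a, %b)
//   %1 = abs(%0)
// where only %1 is used outside the pattern, GetInputsOfFusionPattern
// returns {%a, %b} and GetOutputsOfFusionPattern returns {%1}; %0 stays
// internal to the pattern.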

FusionPattern MergeFusionPattern(const FusionPattern& lhs,
                                 const FusionPattern& rhs) {
  FusionPattern pattern(lhs);
  pattern.insert(pattern.end(), rhs.begin(), rhs.end());
  return pattern;
}

inline int EffectiveSize(const FusionPattern& pattern) {
  return llvm::count_if(
      pattern, [](Operation* op) { return !matchPattern(op, m_Constant()); });
}

// This is a simple shape constraint analysis, which is used to
// guide fusion decisions (e.g. we only fuse shape-compatible ops).
//
// Currently, we only consider shape equality propagation based
// on the shape constraint traits of elementwise ops (assuming that
// implicit shape broadcast is forbidden).
class ShapeConstraintAnalysis {
 public:
  explicit ShapeConstraintAnalysis(
      const SmallVectorImpl<Operation*>& op_list) {
    PropagateEquality(op_list);
  }

  // Returns true if `lhs` and `rhs` are supposed to have the same shape.
  bool HasSameShape(Value lhs, Value rhs) {
    return impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
  }

 private:
  // Shape equality propagation based on the shape constraints of
  // elementwise ops.
  void PropagateEquality(const SmallVectorImpl<Operation*>& op_list) {
    bool converged = true;
    do {
      converged = true;
      auto update = [&](Value lhs, Value rhs) {
        if (!impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
          converged = false;
          impl_.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
        }
      };
      for (Operation* op : op_list) {
        auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
        if (!op_fusibility) continue;
        int numInput = op->getNumOperands();
        int numOutput = op->getNumResults();
        // Shape equality propagation between inputs.
        for (int input1 = 0; input1 < numInput; ++input1)
          for (int input2 = input1 + 1; input2 < numInput; ++input2)
            if (op_fusibility.inferInputsShapeEquality(input1, input2))
              update(op->getOperand(input1), op->getOperand(input2));

        // Shape equality propagation between outputs.
        for (int output1 = 0; output1 < numOutput; ++output1)
          for (int output2 = output1 + 1; output2 < numOutput; ++output2)
            if (op_fusibility.inferOutputsShapeEquality(output1, output2))
              update(op->getResult(output1), op->getResult(output2));

        // Shape equality propagation between inputs and outputs.
        for (int input = 0; input < numInput; ++input)
          for (int output = 0; output < numOutput; ++output)
            if (op_fusibility.inferInputOutputShapeEquality(input, output))
              update(op->getOperand(input), op->getResult(output));
      }
    } while (!converged);
  }

  // A UnionFind set.
  EquivalenceClasses<ValueWrapper> impl_;
};
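
// Example for ShapeConstraintAnalysis above: given two elementwise ops
// (assuming both implement InferFusibilityOpInterface with the usual
// elementwise shape-equality traits)
//   %2 = "mhlo.add"(%0, %1) : ...
//   %3 = "mhlo.abs"(%2) : ...
// the analysis places %0, %1, %2 and %3 into one equivalence class, so
// HasSameShape(%0, %3) returns true even though the actual shape is
// unknown at compile time.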

// A fusion planner that can propose a fusion plan for a block of ops.
// The fusion plan consists of a group of fusion patterns.
//
// Currently all proposed patterns follow the XLA kLoop/kInput fusion
// templates, adapted to the fully dynamic shape world.
//
// The kLoop fusion template satisfies:
//   - All ops in the fusion pattern are elementwise.
//   - All outputs of the fusion pattern have the same shape, and thus can
//     fit into the same parallel loop.
//
// The kInput fusion template satisfies:
//   - Any op in the fusion pattern is either elementwise or a reduction.
//   - If an op is a reduction, its output cannot be consumed by other
//     ops in the same fusion pattern.
//   - All effective shapes of the outputs of the fusion pattern are the
//     same, where:
//     - for an elementwise op, its effective shape is its output shape;
//     - for a reduction op, its effective shape is its operand shape.
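//
// As an example of the effective-shape rule, an elementwise op producing
// tensor<?x?xf32> and a reduction from tensor<?x?xf32> to tensor<?xf32>
// can live in the same kInput pattern, because the reduction's effective
// shape is its operand shape tensor<?x?xf32>. (The types here are only
// illustrative.)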
class FusionPlanner {
 public:
  explicit FusionPlanner(const SmallVectorImpl<Operation*>& op_list)
      : op_list_(op_list),
        shape_analysis_(op_list),
        cycle_detector_(op_list.size()) {
    BuildNodeMap();
  }

  // Returns a fusion plan on success, otherwise none.
  llvm::Optional<FusionPlan> Run() {
    // Greedily search for connected fusible patterns; ops belonging to
    // the same fusion pattern are grouped into a cluster.
    RunEdgeContractionLoop();

    // After doing edge contraction, each unique cluster having size
    // larger than one represents a potential fusion pattern.
    // We collect all these clusters and construct a fusion plan.
    //
    // Note that the ops in a fusion pattern are in topological ordering.
    FusionPlan plan;
    DenseMap<int, int> pattern_ids;
    for (Operation* op : op_list_) {
      Cluster* cluster = GetClusterForNode(op);
      int node_id = cluster->cycles_graph_node_id();
      if (!IsFusible(op_list_[node_id]) ||
          EffectiveSize(GetClusterForNode(op)->fused_pattern()) <= 1) {
        continue;
      }
      if (!pattern_ids.count(node_id)) {
        int pattern_id = pattern_ids.size();
        pattern_ids[node_id] = pattern_id;
        plan.emplace_back();
      }
      plan[pattern_ids[node_id]].push_back(op);
    }
    return plan;
  }

  // Returns the op_list this planner operates on.
  const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }

 private:
  // Represents a (partially) fused pattern.
  class Cluster {
   public:
    Cluster(int node_id, FusionPlanner* planner) : node_id_(node_id) {
      const SmallVectorImpl<Operation*>& op_list = planner->op_list();
      pattern_.push_back(op_list[node_id]);
    }

    // Merges `other` into this cluster, and clears `other`.
    void Merge(Cluster* other) {
      pattern_.insert(pattern_.end(), other->pattern_.begin(),
                      other->pattern_.end());
      other->pattern_.clear();
    }

    // The number of nodes in this cluster.
    int cluster_size() const { return pattern_.size(); }

    // The ID of the cluster as represented in `cycle_detector_`.
    int cycles_graph_node_id() const { return node_id_; }

    // Sets the ID of the cluster as represented in `cycle_detector_`.
    void set_cycles_graph_node_id(int cycles_graph_node_id) {
      node_id_ = cycles_graph_node_id;
    }

    // The fused pattern this cluster currently holds.
    const FusionPattern& fused_pattern() { return pattern_; }

   private:
    // ID of the representative node of this cluster.
    int node_id_;

    // The fused pattern this cluster holds.
    FusionPattern pattern_;
  };

 private:
  Cluster* MakeCluster(int cycles_graph_node_id) {
    cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
    return cluster_storage_.back().get();
  }

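  // Builds the initial single-op clusters and the cycle-detection graph.
  // Note: this assumes the ops in `op_list_` appear in topological order
  // within their block, so a defining op is always visited before its
  // users (otherwise the assert below would fire).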
  void BuildNodeMap() {
    int num_nodes = op_list_.size();
    for (int node_id = 0; node_id < num_nodes; ++node_id) {
      Operation* op = op_list_[node_id];
      MakeCluster(node_id);
      op_to_node_id_[op] = node_id;
      leader_for_node_.insert(node_id);
      for (Value operand : op->getOperands()) {
        Operation* operand_op = operand.getDefiningOp();
        if (operand_op == nullptr) {
          // Skip block arguments.
          continue;
        }
        auto iter = op_to_node_id_.find(operand_op);
        assert(iter != op_to_node_id_.end());
        cycle_detector_.InsertEdge(iter->second, node_id);
      }
    }
  }

  // Returns the cluster that contains this op.
  Cluster* GetClusterForNode(Operation* n) {
    int id = op_to_node_id_.at(n);
    id = leader_for_node_.getLeaderValue(id);
    return cluster_storage_[id].get();
  }

  // Returns the cluster that contains the op having `node_id`.
  Cluster* GetClusterForCyclesGraphNode(int node_id) {
    return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
  }

  // Merges the clusters `cluster_from` and `cluster_to`.
  bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
    int from = cluster_from->cycles_graph_node_id();
    int to = cluster_to->cycles_graph_node_id();

    auto optional_merged_node = cycle_detector_.ContractEdge(from, to);
    if (!optional_merged_node.hasValue()) {
      llvm::dbgs() << "Could not contract " << from << " -> " << to
                   << " because contracting the edge would create a cycle.";
      return false;
    }

    // Merge the clusters.
    cluster_from->Merge(cluster_to);
    cluster_from->set_cycles_graph_node_id(*optional_merged_node);

    // Merge the UnionFind sets.
    leader_for_node_.unionSets(from, to);
    return true;
  }

  template <typename FnTy>
  bool ForEachEdgeInPostOrder(FnTy fn) {
    bool changed = false;
    for (int32_t node : cycle_detector_.AllNodesInPostOrder()) {
      Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
      // Make a copy of the set of successors because we may modify the
      // graph in TryToContractEdge.
      std::vector<int32_t> successors_copy =
          cycle_detector_.SuccessorsCopy(cluster_from->cycles_graph_node_id());

      for (int to : successors_copy) {
        Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
        bool contracted_edge = fn(cluster_from, cluster_to);
        changed |= contracted_edge;
      }
    }

    return changed;
  }

  // Returns the outputs of the pattern that would result from merging the
  // two clusters.
  SmallVector<Value, 4> GetResultsOfFusedPattern(Cluster* from, Cluster* to) {
    FusionPattern fused_pattern =
        MergeFusionPattern(from->fused_pattern(), to->fused_pattern());
    return GetOutputsOfFusionPattern(fused_pattern);
  }

  // This function checks whether fusing `from` with `to` is valid and, if
  // so, performs the merge. The validity is based on the operations in the
  // clusters and the compatibility of the shapes of the outputs of the
  // would-be fused clusters.
  // Returns true if the merge was performed.
  bool TryToContractEdge(Cluster* from, Cluster* to) {
    int node_to = to->cycles_graph_node_id();
    int node_from = from->cycles_graph_node_id();

    // Both node_to and node_from should be fusible.
    if (!IsFusible(op_list_[node_to]) || !IsFusible(op_list_[node_from])) {
      return false;
    }

    auto op_from_fusibility =
        dyn_cast<InferFusibilityOpInterface>(op_list_[node_from]);
    if (op_from_fusibility && !op_from_fusibility.isFusibleWithConsumer()) {
      // This op cannot be fused with its consumers.
      return false;
    }

    auto op_to_fusibility =
        dyn_cast<InferFusibilityOpInterface>(op_list_[node_to]);
    if (op_to_fusibility && !op_to_fusibility.isFusibleWithOperand()) {
      // This op cannot be fused with its operands.
      return false;
    }

    // Output shapes of a fusion pattern should be compatible as described
    // in the documentation of this class.
    SmallVector<Value, 4> results = GetResultsOfFusedPattern(from, to);
    auto get_workload_shape = [](Value v) {
      Operation* op = v.getDefiningOp();
      // Block argument.
      if (!op) return v;
      auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
      // Const value.
      if (!op_fusibility) return v;
      llvm::Optional<Value> workload =
          op_fusibility.inferEffectiveWorkloadShape();
      return workload.hasValue() ? *workload : v;
    };

    Value ref = get_workload_shape(results[0]);
    if (!llvm::all_of(results, [&](Value result) {
          Value val = get_workload_shape(result);
          return shape_analysis_.HasSameShape(ref, val);
        })) {
      return false;
    }

    return MergeClusters(from, to);
  }

  // Greedily fuses connected nodes.
  bool RunEdgeContractionLoop() {
    using std::placeholders::_1;
    using std::placeholders::_2;
    return ForEachEdgeInPostOrder(
        std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2));
  }

  const SmallVectorImpl<Operation*>& op_list_;

  // Shape equality checker.
  ShapeConstraintAnalysis shape_analysis_;

  // op -> node_id
  std::unordered_map<Operation*, int> op_to_node_id_;

  // Makes sure we do not introduce cycles after fusion.
  GraphCycles cycle_detector_;
  std::vector<std::unique_ptr<Cluster>> cluster_storage_;

  // A UnionFind set. Each set represents a (partially) fused pattern
  // and has a leader as its representative.
  EquivalenceClasses<int32_t> leader_for_node_;
};

struct MhloFusionPass : public mlir::PassWrapper<MhloFusionPass, FunctionPass> {
  void runOnFunction() override {
    FuncOp func = getFunction();
    if (!IsTargetFunc(func)) {
      return;
    }

    // Process each block and do fusion within a block.
    for (Block& block : func) {
      SmallVector<Operation*, 4> op_list;
      for (Operation& op : block) {
        op_list.push_back(&op);
      }

      FusionPlanner planner(op_list);
      llvm::Optional<FusionPlan> plan = planner.Run();
      if (!plan) {
        emitError(func.getLoc(), "can't find a fusion plan");
        signalPassFailure();
        return;
      }
      if (!ApplyFusionPlan(*plan)) {
        emitError(func.getLoc(), "apply fusion plan failed");
        signalPassFailure();
        return;
      }
    }
  }

  bool IsTargetFunc(FuncOp func) {
    int num_fusible_ops = 0;
    bool is_target_func = false;
    // We only process functions having enough fusion candidates.
    func.walk([&](Operation* op) {
      num_fusible_ops +=
          static_cast<int>(dyn_cast<InferFusibilityOpInterface>(op) != nullptr);
      is_target_func = (num_fusible_ops > 1);
      // Early stop.
      if (is_target_func) return WalkResult::interrupt();
      return WalkResult::advance();
    });
    return is_target_func;
  }

  bool ApplyFusionPlan(const FusionPlan& plan) {
    for (const FusionPattern& pattern : plan) {
      OpBuilder b(pattern.back());

      SmallVector<Location, 4> locations;
      locations.reserve(pattern.size());
      for (Operation* op : pattern) {
        locations.push_back(op->getLoc());
      }
      Location fused_loc =
          FusedLoc::get(locations, pattern.back()->getContext());

      SmallVector<Value, 4> inputs = GetInputsOfFusionPattern(pattern);
      SmallVector<Value, 4> outputs = GetOutputsOfFusionPattern(pattern);
      SmallVector<Type, 4> output_types;
      output_types.reserve(outputs.size());
      for (Value v : outputs) {
        output_types.push_back(v.getType());
      }

      FusionOp fusion =
          b.create<mhlo::FusionOp>(fused_loc, output_types, inputs);
      Region& region = fusion.fused_computation();
      region.push_back(new Block);
      Block& block = region.front();
      for (Operation* op : pattern) {
        op->moveBefore(&block, block.end());
      }
      b.setInsertionPoint(&block, block.end());
      b.create<mhlo::ReturnOp>(fused_loc, outputs);

      for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
        Value output = std::get<0>(output_and_result);
        Value fusion_result = std::get<1>(output_and_result);
        for (OpOperand& use : llvm::make_early_inc_range(output.getUses())) {
          if (use.getOwner()->getBlock() != &block) use.set(fusion_result);
        }
      }
    }
    return true;
  }
};

}  // namespace

std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass() {
  return std::make_unique<MhloFusionPass>();
}
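
// A typical way to schedule this pass in a pipeline (a sketch only; the
// exact pass-manager setup depends on the host project):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createMhloFusionPass());
//   if (failed(pm.run(module))) { /* handle failure */ }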

}  // namespace mhlo
}  // namespace mlir