1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/Support/Casting.h"
18 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
19 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/IR/Operation.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Pass/PassManager.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Transforms/RegionUtils.h"
26 
27 namespace mlir {
28 namespace mhlo {
29 
30 namespace {
31 
// A pass that sinks constants implicitly captured in control flow regions. This
// is necessary to export to XLA.
//
// TODO(hinsu): Generalize this pass to handle all the ops with regions. Any
// value used within the region that is defined outside of the op's region
// should be sunk into the region, not just the constants. Ops such as If and
// While, whose computations don't require a fixed signature like Sort or
// Reduce, have the option to pass outside values as operands of the op to avoid
// recomputing those values internally. Note that doing so is the only option
// for outside values that are BlockArguments of any of the parent regions.
42 class SinkConstantsToControlFlowPass
43     : public SinkConstantsToControlFlowPassBase<
44           SinkConstantsToControlFlowPass> {
runOnFunction()45   void runOnFunction() override {
46     getFunction().walk([](Operation* op) {
47       if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
48         SinkToRegion(&while_op.body());
49         SinkToRegion(&while_op.cond());
50       } else if (auto if_op = llvm::dyn_cast<IfOp>(op)) {
51         SinkToRegion(&if_op.true_branch());
52         SinkToRegion(&if_op.false_branch());
53       } else if (auto reduce_window_op = llvm::dyn_cast<ReduceWindowOp>(op)) {
54         SinkToRegion(&reduce_window_op.body());
55       } else if (auto sort_op = llvm::dyn_cast<SortOp>(op)) {
56         SinkToRegion(&sort_op.comparator());
57       }
58     });
59   }
60 
61  private:
62   // Performs constant sinking into a region.
SinkToRegion(Region * region)63   static void SinkToRegion(Region* region) {
64     llvm::DenseMap<Value, Operation*> sunk_constant;
65     visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
66       Value constant = use->get();
67       auto op = constant.getDefiningOp();
68       if (!op || !op->hasTrait<OpTrait::ConstantLike>()) return;
69       auto map_entry = sunk_constant.try_emplace(constant, nullptr);
70       if (!map_entry.second) {
71         // This constant has already been cloned into the region, reuse it.
72         use->set(map_entry.first->getSecond()->getResult(0));
73         if (op->use_empty()) op->erase();
74         return;
75       }
76       if (constant.hasOneUse()) {
77         op->moveBefore(&region->front().front());
78         return;
79       }
80       map_entry.first->getSecond() = op->clone();
81       region->front().getOperations().insert(region->front().begin(),
82                                              map_entry.first->getSecond());
83       use->set(map_entry.first->getSecond()->getResult(0));
84     });
85   }
86 };
87 
88 }  // anonymous namespace
89 
90 // TODO(hinsu): Rename this pass and move to a different file along with the
91 // generalization to make all ops isolated from above.
createSinkConstantsToControlFlowPass()92 std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
93   return std::make_unique<SinkConstantsToControlFlowPass>();
94 }
95 
96 }  // namespace mhlo
97 }  // namespace mlir
98