1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Utils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
24 
25 namespace mlir {
26 namespace TF {
27 
28 namespace {
29 
30 // Clones FuncOp's until they have a single use only (or no users).
31 //
32 // The tf-shape-inference pass doesn't support functions that have more than
33 // a single use. But some real code from frontends does end up creating code
34 // like that. For example, the same LSTM cell function or loop body function
35 // will be reused.
36 //
37 // This pass clones functions as needed to establish the invariant that all
38 // functions have a single use. This can in principle cause exponential code
39 // size bloat, and should in general be guided by a proper cost model.
40 //
41 // There are two factors which should be considered by a principled replacement
42 // to this pass:
43 //
44 // 1. TF currently relies on "sufficiently good shape inference" for
45 // correctness so for now the cost of doing this seems acceptable since
46 // pathological cases haven't hit us yet.
47 //
48 // 2. Cloning functions can help by allowing code to be specialized (much as
49 // inlining does). In fact, tf-shape-inference attempts to do specialization
50 // of callees which is difficult if callees have multiple uses.
51 class GuaranteeAllFuncsOneUse
52     : public PassWrapper<GuaranteeAllFuncsOneUse, OperationPass<ModuleOp>> {
53  public:
runOnOperation()54   void runOnOperation() override {
55     if (failed(Run())) {
56       signalPassFailure();
57     }
58   }
59 
Run()60   LogicalResult Run() {
61     auto module = getOperation();
62 
63     // Overall strategy:
64     // Fixed point iteration, iteratively applying a rule that clones
65     // any FuncOp with more than one use to eliminate its uses.
66 
67     SymbolTable symbol_table(module);
68     bool made_changes = false;
69     // This value needs to be low enough to actually stop compilation in a
70     // reasonable time, but not too low that it blocks real programs.
71     // This number was chosen semi-randomly.
72     const int k_max_clones = 1000;
73     int num_clones = 0;
74     do {
75       made_changes = false;
76       for (auto func : llvm::make_early_inc_range(module.getOps<FuncOp>())) {
77         auto uses_optional = symbol_table.getSymbolUses(func, module);
78         if (!uses_optional.hasValue()) {
79           return func.emitError() << "could not walk uses of func";
80         }
81         auto &uses = *uses_optional;
82         if (llvm::size(uses) <= 1) {
83           continue;
84         }
85         // At this point, we know we are going to change the module.
86         made_changes = true;
87         for (const SymbolTable::SymbolUse &use : llvm::drop_begin(uses, 1)) {
88           if (num_clones++ > k_max_clones) {
89             return func.emitError()
90                    << "reached cloning limit (likely recursive call graph or "
91                       "repeated diamond-like call structure "
92                       "or just very large program)";
93           }
94           auto new_func = func.clone();
95           symbol_table.insert(new_func);
96           new_func.setPrivate();
97           if (failed(symbol_table.replaceAllSymbolUses(func, new_func.getName(),
98                                                        use.getUser()))) {
99             return func.emitError() << "could not replace symbol use";
100           }
101         }
102       }
103     } while (made_changes);
104 
105     return success();
106   }
107 };
108 
109 }  // namespace
110 
CreateGuaranteeAllFuncsOneUsePass()111 std::unique_ptr<OperationPass<ModuleOp>> CreateGuaranteeAllFuncsOneUsePass() {
112   return std::make_unique<GuaranteeAllFuncsOneUse>();
113 }
114 
115 static PassRegistration<GuaranteeAllFuncsOneUse> pass(
116     "tf-guarantee-all-funcs-one-use",
117     "Guarantee all FuncOp's have only a single use.");
118 
119 }  // namespace TF
120 
121 }  // namespace mlir
122