1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/None.h"
17 #include "llvm/Support/Casting.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
20 #include "mlir/IR/Attributes.h"  // from @llvm-project
21 #include "mlir/IR/Builders.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
23 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
24 #include "mlir/IR/Operation.h"  // from @llvm-project
25 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
29 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
33 
34 namespace mlir {
35 namespace TFL {
36 namespace {
// This file contains the LegalizeVariables pass, which is responsible for:
// - Converting TF::ReadVariableOp and TF::AssignVariableOp ops whose resources
//   are bound to global tensors into their TFLite equivalent ops.
// Note that this pass assumes all variables are already available as
// GlobalTensorOps and all VarHandleOps have already been converted to function
// arguments carrying the "tf_saved_model.bound_input" attribute.
// It also assumes all other ops are already legalized to TFLite.
// TODO(b/149099381): Handle flex support use cases.
45 
46 // Returns Value representing the resource_id.
GetResourceIDAsI32(int resource_id,Location loc,mlir::OpBuilder & rewriter)47 Value GetResourceIDAsI32(int resource_id, Location loc,
48                          mlir::OpBuilder& rewriter) {
49   return rewriter.create<ConstOp>(
50       loc, DenseElementsAttr::get(
51                RankedTensorType::get({1}, rewriter.getIntegerType(32)),
52                resource_id));
53 }
54 
55 // Helper method that fetches the global tensor that 'op' points to it.
56 template <typename T>
GetGlobalTensor(const SymbolTable & symbol_table,T op,FuncOp func)57 tf_saved_model::GlobalTensorOp GetGlobalTensor(const SymbolTable& symbol_table,
58                                                T op, FuncOp func) {
59   auto block_arg = op.resource().template dyn_cast<BlockArgument>();
60   if (!block_arg) return nullptr;
61   int index = block_arg.getArgNumber();
62   auto sym = func.template getArgAttrOfType<FlatSymbolRefAttr>(
63       index, "tf_saved_model.bound_input");
64   if (!sym) {
65     return nullptr;
66   }
67   return symbol_table.lookup<tf_saved_model::GlobalTensorOp>(sym.getValue());
68 }
69 
GetAssignVariableOp(int variable_id,TF::AssignVariableOp assign_op,mlir::OpBuilder builder)70 mlir::Operation* GetAssignVariableOp(int variable_id,
71                                      TF::AssignVariableOp assign_op,
72                                      mlir::OpBuilder builder) {
73   return builder.create<TFL::AssignVariableOp>(
74       assign_op.getLoc(),
75       GetResourceIDAsI32(variable_id, assign_op.getLoc(), builder),
76       assign_op.value());
77 }
78 
GetReadVariableOp(int variable_id,TF::ReadVariableOp read_op,mlir::OpBuilder builder)79 mlir::Operation* GetReadVariableOp(int variable_id, TF::ReadVariableOp read_op,
80                                    mlir::OpBuilder builder) {
81   return builder.create<TFL::ReadVariableOp>(
82       read_op.getLoc(), read_op.getResult().getType(),
83       GetResourceIDAsI32(variable_id, read_op.getLoc(), builder));
84 }
85 
86 template <typename T>
87 class LegalizeVariablesPattern : public mlir::OpConversionPattern<T> {
88  public:
LegalizeVariablesPattern(mlir::MLIRContext * context,const std::map<std::string,int> * global_tensor_id,SymbolTable symbol_table)89   LegalizeVariablesPattern(mlir::MLIRContext* context,
90                            const std::map<std::string, int>* global_tensor_id,
91                            SymbolTable symbol_table)
92       : mlir::OpConversionPattern<T>(context),
93         global_tensor_id_(global_tensor_id),
94         symbol_table_(symbol_table) {}
95 
matchAndRewrite(T var_op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const96   LogicalResult matchAndRewrite(
97       T var_op, ArrayRef<Value> operands,
98       ConversionPatternRewriter& rewriter) const override {
99     auto* op = var_op.getOperation();
100     auto func = var_op->template getParentOfType<FuncOp>();
101     if (!func) return failure();
102     auto global_tensor = GetGlobalTensor<T>(symbol_table_, var_op, func);
103     if (!global_tensor) return failure();
104     auto variable_id = global_tensor_id_->at(global_tensor.sym_name().str());
105     mlir::OpBuilder builder(var_op);
106     mlir::Operation* tfl_var_op = nullptr;
107     if (llvm::isa<TF::AssignVariableOp>(op)) {
108       auto assign_op = llvm::cast<TF::AssignVariableOp>(op);
109       tfl_var_op = GetAssignVariableOp(variable_id, assign_op, builder);
110     } else {
111       auto read_op = llvm::cast<TF::ReadVariableOp>(op);
112       tfl_var_op = GetReadVariableOp(variable_id, read_op, builder);
113     }
114     var_op->replaceAllUsesWith(tfl_var_op);
115     rewriter.eraseOp(var_op);
116     return success();
117   }
118 
119  private:
120   const std::map<std::string, int>* global_tensor_id_;
121   SymbolTable symbol_table_;
122 };
123 
124 // Pass which legalizes TF variables which are already passed as bounded
125 // arguments to functions, to TFLite variables.
126 class LegalizeVariables
127     : public PassWrapper<LegalizeVariables, OperationPass<ModuleOp>> {
128  public:
129   LegalizeVariables() = default;
LegalizeVariables(const LegalizeVariables &)130   LegalizeVariables(const LegalizeVariables&) {}
131 
runOnOperation()132   void runOnOperation() override {
133     auto module = getOperation();
134     // Use ordered container to make sure ids are deterministic if we got tensor
135     // ids from different part, also easier to debug.
136     // TODO(b/149099381): Remove integer IDs after adding the new variable
137     // handle type.
138     std::map<std::string, int> global_tensor_id;
139     for (auto global_tensor : module.getOps<tf_saved_model::GlobalTensorOp>()) {
140       global_tensor_id[global_tensor.sym_name().str()];
141     }
142     int id = 0;
143     for (auto& tensor : global_tensor_id) tensor.second = id++;
144 
145     SymbolTable symbol_table(module);
146     ConversionTarget target(getContext());
147     OwningRewritePatternList patterns;
148     patterns.insert<LegalizeVariablesPattern<TF::ReadVariableOp>,
149                     LegalizeVariablesPattern<TF::AssignVariableOp>>(
150         &getContext(), &global_tensor_id, symbol_table);
151     if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
152       signalPassFailure();
153       return;
154     }
155   }
156 };
157 
158 }  // namespace
159 
CreateLegalizeVariablesPass()160 std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeVariablesPass() {
161   return std::make_unique<LegalizeVariables>();
162 }
163 
164 static PassRegistration<LegalizeVariables> pass(
165     "tfl-legalize-variables-tf",
166     "Legalize TensorFlow variables to TensorFlow Lite dialect");
167 
168 }  // namespace TFL
169 }  // namespace mlir
170