1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/None.h"
17 #include "llvm/Support/raw_ostream.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
19 #include "mlir/IR/Attributes.h"  // from @llvm-project
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
22 #include "mlir/IR/Operation.h"  // from @llvm-project
23 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
26 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
31 
32 namespace mlir {
33 namespace TFL {
34 namespace {
// Exported name used to recognize an already existing session initializer
// function: a function whose exported names contain this string is reused
// instead of creating a new one.
constexpr char kTfSavedModelSessionInitNameAttr[] =
    "__tf_saved_model_session_initializer";
// Name of the function attribute holding the array of exported names.
constexpr char kTfSavedModelExportedNameAttr[] =
    "tf_saved_model.exported_names";
39 
40 // Returns Value representing the resource_id.
GetResourceIDAsI32(int resource_id,Location loc,mlir::OpBuilder & rewriter)41 Value GetResourceIDAsI32(int resource_id, Location loc,
42                          mlir::OpBuilder& rewriter) {
43   return rewriter.create<ConstOp>(
44       loc, DenseElementsAttr::get(
45                RankedTensorType::get({1}, rewriter.getIntegerType(32)),
46                resource_id));
47 }
48 
49 // Helper method that fetches the global tensor that 'op' points to it.
50 template <typename T>
GetGlobalTensor(const SymbolTable & symbol_table,T op,FuncOp func)51 tf_saved_model::GlobalTensorOp GetGlobalTensor(const SymbolTable& symbol_table,
52                                                T op, FuncOp func) {
53   auto block_arg = op.resource().template dyn_cast<BlockArgument>();
54   if (!block_arg) return nullptr;
55   int index = block_arg.getArgNumber();
56   auto sym = func.template getArgAttrOfType<FlatSymbolRefAttr>(
57       index, "tf_saved_model.bound_input");
58   if (!sym) {
59     return nullptr;
60   }
61   return symbol_table.lookup<tf_saved_model::GlobalTensorOp>(sym.getValue());
62 }
63 
64 // Pass which Initializes TF variables which are already passed as bounded
65 // arguments to functions, to a TFLite variables.
66 class InitializeVariablesPass
67     : public PassWrapper<InitializeVariablesPass, OperationPass<ModuleOp>> {
68  public:
69   InitializeVariablesPass() = default;
InitializeVariablesPass(const InitializeVariablesPass &)70   InitializeVariablesPass(const InitializeVariablesPass&) {}
71 
72   // Initializes a single variable identified by 'var_id' with value 'value'
73   // in 'session_init' function.
InitializeVariable(int var_id,ElementsAttr value,FuncOp session_init)74   void InitializeVariable(int var_id, ElementsAttr value, FuncOp session_init) {
75     // TODO(b/149099381): Initialize using TF::AssignVariableOp instead
76     // and let legalization be handled by Legalize variables pass.
77     mlir::OpBuilder builder(&getContext());
78     builder.setInsertionPoint(&session_init.getBlocks().front().front());
79     auto resource_op =
80         GetResourceIDAsI32(var_id, session_init.body().getLoc(), builder);
81     auto value_op =
82         builder.create<ConstOp>(session_init.body().getLoc(), value);
83     builder.create<TFL::AssignVariableOp>(session_init.body().getLoc(),
84                                           resource_op, value_op);
85   }
86 
GetGlobalTensorOp(mlir::Operation * op,SymbolTable symbol_table,FuncOp func)87   tf_saved_model::GlobalTensorOp GetGlobalTensorOp(mlir::Operation* op,
88                                                    SymbolTable symbol_table,
89                                                    FuncOp func) {
90     if (auto read_var = llvm::dyn_cast_or_null<TF::ReadVariableOp>(op))
91       return GetGlobalTensor<TF::ReadVariableOp>(symbol_table, read_var, func);
92     else if (auto assign_var = llvm::dyn_cast_or_null<TF::AssignVariableOp>(op))
93       return GetGlobalTensor<TF::AssignVariableOp>(symbol_table, assign_var,
94                                                    func);
95     return nullptr;
96   }
97 
98   // Initializes all variables in the module.
InitializeVariables(const std::map<std::string,int> & global_tensor_id,SymbolTable symbol_table)99   void InitializeVariables(const std::map<std::string, int>& global_tensor_id,
100                            SymbolTable symbol_table) {
101     auto module = getOperation();
102     // Check if there is Session init func already, if not create one.
103     FuncOp session_init_func = nullptr;
104     for (auto func : module.getOps<FuncOp>()) {
105       if (auto attr = func->getAttr(kTfSavedModelExportedNameAttr)) {
106         auto exported_names = attr.dyn_cast<ArrayAttr>();
107         if (!exported_names) continue;
108         for (auto exported_name : exported_names) {
109           if (auto name = exported_name.dyn_cast_or_null<StringAttr>())
110             if (name.getValue() == kTfSavedModelSessionInitNameAttr)
111               session_init_func = func;
112         }
113         if (session_init_func) break;
114       }
115     }
116     // TODO(b/149099381): Refactor to separate function in saved model util.
117     if (!session_init_func) session_init_func = CreateSessionInitFunc();
118 
119     std::set<tf_saved_model::GlobalTensorOp> tensors_to_initialize;
120     for (auto func : module.getOps<FuncOp>()) {
121       func->walk([&](Operation* op) {
122         // TODO(b/149099381): Make sure to verify flex compatability
123         // with ops that accepts resource as input.
124         if (!llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp>(op))
125           return WalkResult::advance();
126         tensors_to_initialize.insert(GetGlobalTensorOp(op, symbol_table, func));
127         return WalkResult::advance();
128       });
129     }
130     for (auto global_tensor : tensors_to_initialize) {
131       InitializeVariable(global_tensor_id.at(global_tensor.sym_name().str()),
132                          global_tensor.value(), session_init_func);
133     }
134   }
135   // Create a new function in the module which is SessionInitializerOp.
CreateSessionInitFunc()136   FuncOp CreateSessionInitFunc() {
137     constexpr char kSessionInitFuncName[] = "SessionInitializerFunction";
138     auto module = getOperation();
139 
140     mlir::OpBuilder builder(module.body());
141     auto func_type = FunctionType::get(&getContext(), {}, {});
142     auto func = builder.create<FuncOp>(module->getLoc(), kSessionInitFuncName,
143                                        func_type);
144     func->setAttr(kTfSavedModelExportedNameAttr,
145                   builder.getStrArrayAttr({kSessionInitFuncName}));
146     func.setVisibility(mlir::FuncOp::Visibility::Public);
147     auto funcBuilder = OpBuilder::atBlockBegin(func.addEntryBlock());
148     funcBuilder.create<mlir::ReturnOp>(func.getLoc());
149     builder.create<tf_saved_model::SessionInitializerOp>(
150         module->getLoc(),
151         builder.getArrayAttr(builder.getSymbolRefAttr(kSessionInitFuncName)));
152     return func;
153   }
154 
runOnOperation()155   void runOnOperation() override {
156     auto module = getOperation();
157     // Use ordered container to make sure ids are deterministic if we got tensor
158     // ids from different part, since we have different passes that touches
159     // variables.
160     // TODO(b/149099381): Remove integer IDs after adding the new variable
161     // handle type.
162     std::map<std::string, int> global_tensor_id;
163     int id = 0;
164     for (auto global_tensor : module.getOps<tf_saved_model::GlobalTensorOp>()) {
165       global_tensor_id[global_tensor.sym_name().str()];
166     }
167     for (auto& tensor : global_tensor_id) tensor.second = id++;
168     SymbolTable symbol_table(module);
169 
170     // Initialize all variables.
171     InitializeVariables(global_tensor_id, symbol_table);
172   }
173 };
174 }  // namespace
175 
CreateInitializeVariablesPass()176 std::unique_ptr<OperationPass<ModuleOp>> CreateInitializeVariablesPass() {
177   return std::make_unique<InitializeVariablesPass>();
178 }
179 
// Registers InitializeVariablesPass under the argument
// "tfl-initialize-variables-tf" together with its description.
static PassRegistration<InitializeVariablesPass> pass(
    "tfl-initialize-variables-tf",
    "Initialize TensorFlow variables to TensorFlow Lite dialect");
183 
184 }  // namespace TFL
185 }  // namespace mlir
186