1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <memory>
#include <set>

#include "llvm/ADT/None.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
29 
30 namespace mlir {
31 namespace TFL {
32 namespace {
33 // Pass which removes any unused bounded function arguments which maps to
34 // variables, also removes the GlobalTensor which is the variable.
35 class RemoveArgsAndGlobalTensors
36     : public PassWrapper<RemoveArgsAndGlobalTensors, OperationPass<ModuleOp>> {
37  public:
38   RemoveArgsAndGlobalTensors() = default;
RemoveArgsAndGlobalTensors(const RemoveArgsAndGlobalTensors &)39   RemoveArgsAndGlobalTensors(const RemoveArgsAndGlobalTensors&) {}
40 
runOnOperation()41   void runOnOperation() override {
42     auto module = getOperation();
43     SymbolTable symbol_table(module);
44 
45     // Remove unused arguments in the functions which are bounded input
46     // for a global tensor. Also, removes the now unused global tensors.
47     std::set<mlir::tf_saved_model::GlobalTensorOp> global_tensors_to_remove;
48     for (auto func : module.getOps<FuncOp>()) {
49       llvm::SmallVector<unsigned int> index_to_remove;
50       for (int i = 0; i < func.getNumArguments(); ++i) {
51         if (auto sym = func.template getArgAttrOfType<FlatSymbolRefAttr>(
52                 i, "tf_saved_model.bound_input")) {
53           auto global_tensor =
54               symbol_table.lookup<tf_saved_model::GlobalTensorOp>(
55                   sym.getValue());
56           if (global_tensor && func.getArgument(i).getUsers().empty()) {
57             index_to_remove.push_back(i);
58             global_tensors_to_remove.insert(global_tensor);
59           }
60         }
61       }
62       func.eraseArguments(index_to_remove);
63     }
64     for (auto global_tensor : global_tensors_to_remove) {
65       global_tensor->erase();
66     }
67   }
68 };
69 
70 }  // namespace
71 
CreateRemoveArgsAndGlobalTensors()72 std::unique_ptr<OperationPass<ModuleOp>> CreateRemoveArgsAndGlobalTensors() {
73   return std::make_unique<RemoveArgsAndGlobalTensors>();
74 }
75 
76 static PassRegistration<RemoveArgsAndGlobalTensors> pass(
77     "tfl-remove-unused-function-args",
78     "Removes unused bounded input arguments to function which are unused and "
79     "maps to GlobalTensor.");
80 
81 }  // namespace TFL
82 }  // namespace mlir
83