/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass optimizes tf_saved_model.global_tensor ops: it removes the
// `is_mutable` attribute from global tensors that are provably never written,
// and it erases global tensors and bound-input arguments that are unused.
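//
// Illustrative sketch of the main rewrite (the IR below is schematic; the
// symbol names, types, and values are hypothetical):
//
//   "tf_saved_model.global_tensor"() {is_mutable, sym_name = "v",
//       type = tensor<f32>, value = dense<1.0> : tensor<f32>} : () -> ()
//   func @f(%arg0: tensor<!tf.resource<tensor<f32>>>
//               {tf_saved_model.bound_input = @v})
//       attributes {tf_saved_model.exported_names = ["f"]} {
//     %0 = "tf.ReadVariableOp"(%arg0) : (...) -> tensor<f32>
//     ...
//   }
//
// Since %arg0 is only ever read, the pass drops the `is_mutable` attribute
// from @v. Global tensors that are neither exported nor bound to any
// argument are erased, as are bound-input arguments with no uses.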

#include <cstddef>
#include <map>
#include <set>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace mlir {
namespace tf_saved_model {
namespace {
struct OptimizeGlobalTensorsPass
    : public PassWrapper<OptimizeGlobalTensorsPass, OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

// A global tensor may be bound to arguments of multiple funcs.
// This struct records one such binding: the func and the index of the
// argument within that func to which the global tensor is bound.
struct GlobalTensorUse {
  mutable FuncOp func;
  size_t arg_index;
};

// Maps each global tensor to the (func, argument) bindings that use it.
using GlobalTensorUsesMap =
    std::map<GlobalTensorOp, std::vector<GlobalTensorUse>>;

// Returns true if `type` is a tensor type whose element type is a TF
// resource.
bool IsResourceType(Type type) {
  if (auto tensor_type = type.dyn_cast<TensorType>()) {
    return tensor_type.getElementType().isa<TF::ResourceType>();
  }
  return false;
}

// Returns true if `value` is a resource tensor.
bool IsResource(Value value) { return IsResourceType(value.getType()); }

// Analyzes which resource values in a module may be written to. A resource is
// "potentially written" if any op reachable from it (directly, or through
// calls and control flow) might assign to it.
class ResourceAnalyzer {
 public:
  explicit ResourceAnalyzer(ModuleOp module) {
    for (auto func : module.getOps<FuncOp>()) {
      (void)AnalyzeFunc(func);
    }
  }

  // Returns whether `resource` was marked as potentially written during the
  // analysis. Resources that were never seen are treated as not written.
  bool IsPotentiallyWritten(Value resource) const {
    assert(IsResource(resource));
    auto it = resource_infos_.find(resource);
    if (it == resource_infos_.end()) {
      return false;
    }
    return it->second.potentially_written;
  }

 private:
  // Analyzes the given func for resource-mutating operations (namely
  // TF::AssignVariableOp) and marks the corresponding resources as
  // "potentially written". Recurses into funcs reached via call or control
  // flow ops.
  // TODO(ashwinm): Move to iterative traversal.
  LogicalResult AnalyzeFunc(FuncOp func) {
    // Avoid infinite recursion.
    if (!discovered_.insert(func).second) {
      return success();
    }

    func.walk([&](Operation* op) {
      if (isa<TF::ReadVariableOp, ReturnOp>(op)) {
        return;
      }
      if (auto assign_variable = dyn_cast<TF::AssignVariableOp>(op)) {
        SetPotentiallyWritten(assign_variable.resource());
        return;
      }
      if (auto call = dyn_cast<CallOpInterface>(op)) {
        if (auto func = dyn_cast<FuncOp>(call.resolveCallable())) {
          PropagatePotentiallyWrittenUpFromCallee(func, call.getArgOperands());
        }
        return;
      }
      if (auto if_op = dyn_cast<TF::IfOp>(op)) {
        for (auto callee : {if_op.then_function(), if_op.else_function()}) {
          PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input());
        }
        return;
      }
      if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
        for (auto callee :
             {while_op.cond_function(), while_op.body_function()}) {
          PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input());
        }
        return;
      }
      // For all other ops, conservatively assume they mutate every resource
      // they use. We should improve this by using a property or trait that
      // clearly identifies ops with resource-mutating behavior.
      PropagatePotentiallyWrittenWithinUnhandledOp(op);
    });
    return success();
  }

  // For an op that is not one of the specifically handled kinds above,
  // conservatively treat every resource it uses, whether as an operand or as
  // a value captured by its regions, as potentially written.
  void PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) {
    for (auto operand : op->getOperands()) {
      if (IsResource(operand)) {
        SetPotentiallyWritten(operand);
      }
    }
    visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
      if (IsResource(operand->get())) {
        SetPotentiallyWritten(operand->get());
      }
    });
  }

  // Given the callee func and the operands of the corresponding call-like op,
  // marks a call operand as potentially written if the callee argument it is
  // passed to is a potentially written resource.
  void PropagatePotentiallyWrittenUpFromCallee(
      FuncOp func, Operation::operand_range propagate_to) {
    (void)AnalyzeFunc(func);
    for (auto t : llvm::zip(func.getArguments(), propagate_to)) {
      if (!IsResource(std::get<0>(t))) {
        continue;
      }
      if (IsPotentiallyWritten(std::get<0>(t))) {
        SetPotentiallyWritten(std::get<1>(t));
      }
    }
  }

  // Records that `resource` may be written to.
  void SetPotentiallyWritten(Value resource) {
    assert(IsResource(resource));
    resource_infos_[resource].potentially_written = true;
  }

  struct ResourceInfo {
    bool potentially_written = false;
  };
  // Key: a resource Value.
  // Value: what we know about that Value.
  // Note that these Values in general live in different functions.
  DenseMap<Value, ResourceInfo> resource_infos_;
  // The set of funcs we have already discovered.
  DenseSet<FuncOp> discovered_;
};
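
// Illustrative sketch of how ResourceAnalyzer propagates write information
// across calls (schematic IR; the op and function names are hypothetical):
//
//   func @caller(%arg0: tensor<!tf.resource<tensor<f32>>>) {
//     "tf.StatefulPartitionedCall"(%arg0) {f = @callee, ...} : (...) -> ()
//     return
//   }
//   func @callee(%arg0: tensor<!tf.resource<tensor<f32>>>) {
//     "tf.AssignVariableOp"(%arg0, %cst) : (...) -> ()
//     return
//   }
//
// AnalyzeFunc marks @callee's %arg0 as potentially written because of the
// assign, and PropagatePotentiallyWrittenUpFromCallee then marks the value
// @caller passes into the call as potentially written as well.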

// Returns true if `global_tensor`'s `is_mutable` attribute can safely be
// removed: the tensor is currently marked mutable, is not exported, and none
// of the arguments bound to it is potentially written according to
// `resource_analyzer`.
bool IsImmutable(GlobalTensorOp global_tensor,
                 ArrayRef<GlobalTensorUse> global_tensor_uses,
                 const ResourceAnalyzer& resource_analyzer) {
  // The global tensor is already immutable; there is nothing to change.
  if (!global_tensor.is_mutable()) {
    return false;
  }
  // An exported global tensor that is not already known to be immutable might
  // be externally mutated.
  if (IsExported(global_tensor)) {
    return false;
  }

  // A global tensor is immutable if the resource analyzer deems it so.
  for (auto& global_tensor_use : global_tensor_uses) {
    auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
    if (resource_analyzer.IsPotentiallyWritten(arg)) {
      return false;
    }
  }
  return true;
}

// Builds a map from each global tensor to the (func, argument index) pairs
// bound to it via the `tf_saved_model.bound_input` argument attribute.
GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
  GlobalTensorUsesMap global_tensor_uses;

  SymbolTable symbol_table(module);
  for (auto func : module.getOps<FuncOp>()) {
    for (size_t i = 0, e = func.getNumArguments(); i < e; i++) {
      auto sym =
          func.getArgAttrOfType<SymbolRefAttr>(i, "tf_saved_model.bound_input");
      if (!sym) {
        continue;
      }
      auto global_tensor = symbol_table.lookup<GlobalTensorOp>(
          sym.cast<FlatSymbolRefAttr>().getValue());
      if (!global_tensor) {
        continue;
      }
      global_tensor_uses[global_tensor].push_back({func, i});
    }
  }

  return global_tensor_uses;
}

// Removes the `is_mutable` attribute from tf_saved_model.global_tensor ops
// where we can prove it is safe to do so.
void MarkGlobalTensorsImmutable(
    ModuleOp module, const GlobalTensorUsesMap& global_tensor_uses_map,
    const ResourceAnalyzer& resource_analyzer) {
  for (const auto& kv : global_tensor_uses_map) {
    auto global_tensor = kv.first;
    const auto& global_tensor_uses = kv.second;
    if (IsImmutable(global_tensor, global_tensor_uses, resource_analyzer)) {
      global_tensor.removeAttr("is_mutable");
    }
  }
}

// Erases global tensors that are neither exported nor bound to any func
// argument; such tensors are unreachable from the saved model's interface.
void EraseUnusedGlobalTensors(ModuleOp module,
                              const GlobalTensorUsesMap& global_tensor_uses) {
  for (auto global_tensor :
       llvm::make_early_inc_range(module.getOps<GlobalTensorOp>())) {
    // If the tensor is exported, then it is used.
    if (IsExported(global_tensor)) {
      continue;
    }
    // If the tensor is bound to an argument, then it is used.
    if (global_tensor_uses.find(global_tensor) != global_tensor_uses.end()) {
      continue;
    }
    // Erase it.
    global_tensor.erase();
  }
}

// Erases func arguments that are bound to a global tensor but have no uses.
void EraseUnusedBoundInputs(ModuleOp module) {
  for (auto func : module.getOps<FuncOp>()) {
    SmallVector<unsigned, 4> args_to_erase;
    for (int i = 0, e = func.getNumArguments(); i < e; i++) {
      if (func.getArgAttr(i, "tf_saved_model.bound_input") &&
          func.getArgument(i).use_empty()) {
        args_to_erase.push_back(i);
      }
    }
    func.eraseArguments(args_to_erase);
  }
}

void OptimizeGlobalTensorsPass::runOnOperation() {
  auto module = getOperation();
  if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
    return;
  }

  EraseUnusedBoundInputs(module);

  ResourceAnalyzer resource_analyzer(module);

  GlobalTensorUsesMap global_tensor_uses = CreateGlobalTensorUsesMap(module);

  MarkGlobalTensorsImmutable(module, global_tensor_uses, resource_analyzer);

  EraseUnusedGlobalTensors(module, global_tensor_uses);
}

// For "opt" to pick up this pass.
PassRegistration<OptimizeGlobalTensorsPass> pass(
    "tf-saved-model-optimize-global-tensors",
    "Optimize tf_saved_model.global_tensor's.");

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass() {
  return std::make_unique<OptimizeGlobalTensorsPass>();
}

}  // namespace tf_saved_model
}  // namespace mlir