1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This pass promotes resource accesses in the main function to input arguments
17 // and outputs of the main function.
18 //
19 // Two types of resources are supported:
20 // (1) A function argument of TF::ResourceType type.
21 // (2) A VarHandleOp in the function.
22 //
23 // After the pass,
24 //
25 //  . The function will have an input argument for each resource that is
26 //    already provided as an input argument or is read. The type of the input
27 //    argument will become the shape of the value represented by the resource.
28 //
29 //  . The function will have an output for each resource that is written. The
30 //    type of the output will become the shape of the resource.
31 //
// The information of variable identification and input-output aliasing is
33 // recorded as named attributes of the input argument or output:
34 //
35 //  . 'tf.resource_name' matches 'shared_name' of VarHandleOp, which represents
36 //    the identifier of the corresponding resource. This attribute is added to
37 //    an input argument if the initial value of the resource is read, or to the
38 //    output if the initial value is not read.
39 //
40 //  . 'tf.aliasing_output' is the index of the function output that is an alias
41 //    of the input argument. This attribute is added only to the input argument
42 //    when the initial value of the corresponding resource is read, and the
43 //    resource is written later.
44 //
45 // Assumption of this pass:
46 //  . Compound resource operations have already been decomposed.
47 //  . Dead functions have already been removed, as resource arguments in dead
48 //    functions can cause the pass to fail.
49 
50 #include "llvm/ADT/ArrayRef.h"
51 #include "llvm/ADT/DenseMap.h"
52 #include "llvm/ADT/PointerUnion.h"
53 #include "llvm/ADT/STLExtras.h"
54 #include "llvm/ADT/SmallSet.h"
55 #include "llvm/ADT/SmallVector.h"
56 #include "llvm/ADT/StringExtras.h"
57 #include "llvm/ADT/StringRef.h"
58 #include "llvm/Support/Casting.h"
59 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
60 #include "mlir/IR/Attributes.h"  // from @llvm-project
61 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
62 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
63 #include "mlir/IR/Types.h"  // from @llvm-project
64 #include "mlir/IR/Value.h"  // from @llvm-project
65 #include "mlir/Pass/Pass.h"  // from @llvm-project
66 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
67 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
68 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
69 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
70 
71 namespace mlir {
72 namespace TF {
73 namespace {
74 
// Diagnostic emitted when a resource read/write does not operate on a
// function-level block argument.
constexpr char kResourceFunctionMsg[] =
    "expects function level resource argument";
// Diagnostic emitted when a resource is neither a `tf.VarHandleOp` result nor
// a function argument.
constexpr char kInvalidResourceMsg[] =
    "expects resource to be a VarHandleOp or function argument";
// Argument/result attribute name carrying the variable's shared name.
constexpr char kResourceNameArgAttr[] = "tf.resource_name";
80 
81 // Checks if a function has only one block.
CheckSingleBlockFunction(FuncOp function)82 mlir::LogicalResult CheckSingleBlockFunction(FuncOp function) {
83   if (!llvm::hasSingleElement(function)) {
84     return function.emitError()
85            << "expects function '" << function.getName()
86            << "' to have 1 block, got " << function.getBlocks().size();
87   }
88   return success();
89 }
90 
91 // Collects names of users of a resource that are not `tf.ReadVariableOp` and
92 // not `tf.AssignVariableOp`.
GetCompositeResourceUserNames(Value resource)93 llvm::SmallSet<llvm::StringRef, 1> GetCompositeResourceUserNames(
94     Value resource) {
95   // SmallSet will use a vector when there is only one element and use std::set
96   // when there are more than one elements. This ensures that the operations in
97   // the error message are ordered.
98   llvm::SmallSet<llvm::StringRef, 1> composite_users;
99   for (Operation* user : resource.getUsers())
100     if (!llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp>(user))
101       composite_users.insert(user->getName().getStringRef());
102 
103   return composite_users;
104 }
105 
106 // Checks that the only users of `tf.VarHandleOp` are
107 // `tf.ReadVariableOp` and `tf.AssignVariableOp`.
ValidateVarHandle(TF::VarHandleOp var_handle_op)108 mlir::LogicalResult ValidateVarHandle(TF::VarHandleOp var_handle_op) {
109   auto composite_ops = GetCompositeResourceUserNames(var_handle_op);
110   if (!composite_ops.empty())
111     return var_handle_op.emitOpError()
112            << "expects users to be 'tf.ReadVariableOp' or "
113               "'tf.AssignVariableOp', got ["
114            << llvm::join(composite_ops.begin(), composite_ops.end(), ", ")
115            << "]";
116 
117   return success();
118 }
119 
120 // Checks if resource argument has a valid resource subtype and its users are of
121 // `tf.ReadVariableOp` and `tf.AssignVariableOp` only.
ValidateResourceArgument(FuncOp function,BlockArgument resource_arg,TF::ResourceType resource_type)122 mlir::LogicalResult ValidateResourceArgument(FuncOp function,
123                                              BlockArgument resource_arg,
124                                              TF::ResourceType resource_type) {
125   if (resource_type.getSubtypes().size() != 1)
126     return function.emitError()
127            << "expects resource type of argument "
128            << resource_arg.getArgNumber() << " to have one subtype, got "
129            << resource_type;
130 
131   auto composite_ops = GetCompositeResourceUserNames(resource_arg);
132   if (!composite_ops.empty())
133     return function.emitError()
134            << "expects users of resource argument "
135            << resource_arg.getArgNumber()
136            << " to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp', got ["
137            << llvm::join(composite_ops.begin(), composite_ops.end(), ", ")
138            << "]";
139 
140   return success();
141 }
142 
143 // Adds resource arguments for every unique (name) variable handle. Associated
144 // `tf.VarHandleOp` are removed from the function. Variable shared names are
145 // returned in `var_handle_shared_names` based on the ordering of added resource
146 // arguments.
PromoteVarHandlesToArguments(FuncOp function,bool add_validation,llvm::SmallVectorImpl<std::string> * var_handle_shared_names)147 mlir::LogicalResult PromoteVarHandlesToArguments(
148     FuncOp function, bool add_validation,
149     llvm::SmallVectorImpl<std::string>* var_handle_shared_names) {
150   Block& block = function.front();
151   auto func_type = function.getType();
152 
153   auto func_arg_types = llvm::to_vector<4>(func_type.getInputs());
154   llvm::SmallDenseMap<llvm::StringRef, int> var_arg_index_by_name;
155   for (auto var_handle_op :
156        llvm::make_early_inc_range(block.getOps<TF::VarHandleOp>())) {
157     if (add_validation && failed(ValidateVarHandle(var_handle_op)))
158       return failure();
159 
160     llvm::StringRef name = var_handle_op.shared_nameAttr().getValue();
161     auto it = var_arg_index_by_name.insert({name, func_arg_types.size()});
162     if (it.second) {
163       var_handle_shared_names->emplace_back(name);
164       auto resource_type = var_handle_op.resource().getType();
165       func_arg_types.push_back(resource_type);
166       var_handle_op.resource().replaceAllUsesWith(
167           block.addArgument(resource_type));
168     } else {
169       var_handle_op.resource().replaceAllUsesWith(
170           block.getArgument(it.first->getSecond()));
171     }
172     var_handle_op.erase();
173   }
174 
175   if (!var_handle_shared_names->empty())
176     function.setType(FunctionType::get(function.getContext(), func_arg_types,
177                                        func_type.getResults()));
178 
179   return success();
180 }
181 
// Records the current live value for a resource variable and whether a read or
// write on the variable occurred.
struct ResourceInfo {
  // Latest value of the resource: initially the promoted block argument, then
  // the operand of the most recent `tf.AssignVariableOp`. Remains null for
  // non-resource arguments.
  Value live_value = nullptr;
  // Set when a `tf.ReadVariableOp` of this resource is seen.
  bool read = false;
  // Set when a `tf.AssignVariableOp` of this resource is seen.
  bool write = false;
};
189 
PromoteResourcesToArguments(FuncOp function,llvm::ArrayRef<std::string> var_handle_shared_names)190 LogicalResult PromoteResourcesToArguments(
191     FuncOp function, llvm::ArrayRef<std::string> var_handle_shared_names) {
192   Block& block = function.front();
193 
194   auto return_op = llvm::dyn_cast_or_null<ReturnOp>(block.getTerminator());
195   if (!return_op)
196     return function.emitError() << "expects function '" << function.getName()
197                                 << "' to have a MLIR ReturnOp";
198 
199   llvm::SmallVector<ResourceInfo, 4> resources(function.getNumArguments());
200   auto argument_types = llvm::to_vector<4>(function.getType().getInputs());
201   bool has_resources = false;
202   auto add_resource_argument = [&](BlockArgument arg,
203                                    TF::ResourceType resource_type) {
204     Type arg_type = resource_type.getSubtypes().front();
205     arg.setType(arg_type);
206     resources[arg.getArgNumber()].live_value = arg;
207     argument_types[arg.getArgNumber()] = arg_type;
208     has_resources = true;
209   };
210 
211   // Loop through the non `tf.VarHandleOp` resource arguments in the function,
212   // validate its uses and subtype, and store a mapping from that argument to
213   // itself as the current live value.
214   auto func_args = function.getArguments().take_front(
215       function.getNumArguments() - var_handle_shared_names.size());
216   for (BlockArgument& func_arg : func_args) {
217     auto resource_type =
218         getElementTypeOrSelf(func_arg.getType()).dyn_cast<TF::ResourceType>();
219     if (!resource_type) continue;
220     if (failed(ValidateResourceArgument(function, func_arg, resource_type)))
221       return failure();
222 
223     add_resource_argument(func_arg, resource_type);
224   }
225 
226   // Loop through `tf.VarHandleOp` resource arguments in the function and store
227   // a mapping from that argument to itself as the current live value. No
228   // validations are necessary here as these arguments were validated prior to
229   // being added.
230   auto var_handle_args =
231       function.getArguments().take_back(var_handle_shared_names.size());
232   for (BlockArgument& var_handle_arg : var_handle_args) {
233     auto resource_type =
234         getElementTypeOrSelf(var_handle_arg.getType()).cast<TF::ResourceType>();
235     add_resource_argument(var_handle_arg, resource_type);
236   }
237 
238   if (!has_resources) return success();
239 
240   // We initially assign the argument for a resource as the live value for the
241   // resource. We then walk through the operations in the function in their
242   // lexical order, to update the live value for the resource when we see a
243   // store to the resource and replace reads of the resource with uses of its
244   // live value.
245   for (Operation& op : llvm::make_early_inc_range(block)) {
246     if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(&op)) {
247       if (auto func_arg = read_op.resource().dyn_cast<BlockArgument>()) {
248         if (func_arg.getOwner() != &block)
249           return read_op.emitOpError(kResourceFunctionMsg);
250 
251         ResourceInfo& resource_info = resources[func_arg.getArgNumber()];
252         resource_info.read = true;
253         read_op.value().replaceAllUsesWith(resource_info.live_value);
254       } else {
255         return read_op.emitOpError(kInvalidResourceMsg);
256       }
257 
258       read_op.erase();
259     } else if (auto write_op = llvm::dyn_cast<TF::AssignVariableOp>(&op)) {
260       if (auto func_arg = write_op.resource().dyn_cast<BlockArgument>()) {
261         if (func_arg.getOwner() != &block)
262           return write_op.emitOpError(kResourceFunctionMsg);
263 
264         ResourceInfo& resource_info = resources[func_arg.getArgNumber()];
265         resource_info.write = true;
266         resource_info.live_value = write_op.value();
267       } else {
268         return read_op.emitOpError(kInvalidResourceMsg);
269       }
270 
271       write_op.erase();
272     }
273   }
274 
275   const int64_t num_results_before = function.getNumResults();
276   auto return_operands = llvm::to_vector<4>(return_op.getOperands());
277   auto result_types = llvm::to_vector<4>(return_op.getOperandTypes());
278   llvm::SmallVector<std::pair<int64_t, llvm::StringRef>, 4>
279       output_only_resources;
280   llvm::SmallVector<std::pair<int64_t, int64_t>, 4> input_output_alias;
281 
282   // Collect new return values for variable writes and either (a) output-only
283   // resource attributes (if the resource is not promoted to an argument) or (b)
284   // mapping from resource input index to output alias (if the resource has been
285   // promoted to an argument). Resource arguments that were originally
286   // `tf.VarHandleOp` but not read are collected and then removed.
287   OpBuilder builder(return_op);
288   const int var_handles_start_idx =
289       function.getNumArguments() - var_handle_shared_names.size();
290   int new_argument_index = 0;
291   llvm::SmallVector<int, 4> argument_indices_to_remove;
292   for (auto resource_and_index : llvm::enumerate(resources)) {
293     const auto& resource = resource_and_index.value();
294     if (!resource.live_value) {
295       // Ignore non resource arguments.
296       ++new_argument_index;
297       continue;
298     }
299 
300     const int64_t index = resource_and_index.index();
301     const bool is_var_handle = index >= var_handles_start_idx;
302     if (resource.write) {
303       if (!is_var_handle || resource.read) {
304         input_output_alias.push_back(
305             {new_argument_index, return_operands.size()});
306       } else if (is_var_handle) {
307         output_only_resources.push_back(
308             {return_operands.size(),
309              var_handle_shared_names[index - var_handles_start_idx]});
310       }
311       return_operands.push_back(resource.live_value);
312       result_types.push_back(resource.live_value.getType());
313     }
314 
315     if (is_var_handle && !resource.read) {
316       assert(block.getArgument(index).getUses().empty());
317       argument_indices_to_remove.push_back(index);
318     } else {
319       if (is_var_handle) {
320         // Add resource_name attribute to VarHandleOp read.
321         function.setArgAttr(
322             new_argument_index, kResourceNameArgAttr,
323             builder.getStringAttr(
324                 var_handle_shared_names[index - var_handles_start_idx]));
325       }
326       ++new_argument_index;
327     }
328   }
329 
330   // Remove unread var handle arguments.
331   for (int argument_index_to_remove :
332        llvm::reverse(argument_indices_to_remove)) {
333     block.eraseArgument(argument_index_to_remove);
334     argument_types.erase(argument_types.begin() + argument_index_to_remove);
335   }
336 
337   // Rewrite return if there are variable writes.
338   const int return_operands_size = return_operands.size();
339   if (return_operands_size > num_results_before) {
340     builder.create<ReturnOp>(return_op.getLoc(), return_operands);
341     return_op.erase();
342   }
343 
344   // Update function argument and result types with new resource subtypes.
345   function.setType(builder.getFunctionType(argument_types, result_types));
346 
347   // Add resource_name attribute to the output for the resources.
348   for (auto& resource : output_only_resources)
349     function.setResultAttr(resource.first, kResourceNameArgAttr,
350                            builder.getStringAttr(resource.second));
351 
352   // Add aliasing_output attribute to the input argument for the resources that
353   // are updated by the function.
354   for (auto& input_output : input_output_alias)
355     function.setArgAttr(input_output.first, "tf.aliasing_output",
356                         builder.getI64IntegerAttr(input_output.second));
357 
358   return success();
359 }
360 
// Module pass that promotes resource reads/writes in the `main` function to
// inputs and outputs of `main`. See the file header for the full contract.
class PromoteResourcesToArgsPass
    : public PassWrapper<PromoteResourcesToArgsPass, OperationPass<ModuleOp>> {
 public:
  void runOnOperation() override;
};
366 
runOnOperation()367 void PromoteResourcesToArgsPass::runOnOperation() {
368   ModuleOp module = getOperation();
369   FuncOp main_func = module.lookupSymbol<FuncOp>("main");
370   if (!main_func) return;
371 
372   // This routine should only be called when control flow operations are still
373   // represented with TF IfOp and WhileOp operations. In this case, there should
374   // be only one basic blocks in the MLIR representation.
375   if (failed(CheckSingleBlockFunction(main_func))) return signalPassFailure();
376 
377   llvm::SmallVector<std::string, 4> var_handle_shared_names;
378   if (failed(ResourceLiftingForFunctionalControlFlow(main_func)) ||
379       failed(PromoteVarHandlesToArguments(main_func, /*add_validation=*/true,
380                                           &var_handle_shared_names)) ||
381       failed(PromoteResourcesToArguments(main_func, var_handle_shared_names)))
382     return signalPassFailure();
383 }
384 
// Module pass that promotes `tf.VarHandleOp`s in every function to resource
// arguments, tagging each new argument with its `tf.resource_name`.
class PromoteVarHandlesToArgsPass
    : public PassWrapper<PromoteVarHandlesToArgsPass, OperationPass<ModuleOp>> {
 public:
  void runOnOperation() override;
};
390 
runOnOperation()391 void PromoteVarHandlesToArgsPass::runOnOperation() {
392   ModuleOp module = getOperation();
393   MLIRContext* context = module.getContext();
394   for (auto function : module.getOps<FuncOp>()) {
395     if (failed(CheckSingleBlockFunction(function))) return signalPassFailure();
396 
397     llvm::SmallVector<std::string, 4> var_handle_shared_names;
398     (void)PromoteVarHandlesToArguments(function, /*add_validation=*/false,
399                                        &var_handle_shared_names);
400 
401     // Add resource names for each `tf.VarHandleOp` that were promoted to
402     // resource arguments.
403     const int var_handle_args_offset =
404         function.getNumArguments() - var_handle_shared_names.size();
405     for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names))
406       function.setArgAttr(var_name_and_index.index() + var_handle_args_offset,
407                           kResourceNameArgAttr,
408                           StringAttr::get(context, var_name_and_index.value()));
409   }
410 }
411 
412 }  // namespace
413 
CreatePromoteResourcesToArgsPass()414 std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass() {
415   return std::make_unique<PromoteResourcesToArgsPass>();
416 }
417 
CreatePromoteVarHandlesToArgsPass()418 std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass() {
419   return std::make_unique<PromoteVarHandlesToArgsPass>();
420 }
421 
// Static registrations expose the passes under their command-line flags.
static PassRegistration<PromoteResourcesToArgsPass> pass(
    "tf-promote-resources-to-args",
    "Promote resources reads/writes to function inputs/outputs.");

static PassRegistration<PromoteVarHandlesToArgsPass> var_handle_pass(
    "tf-promote-var-handles-to-args",
    "Promote tf.VarHandleOps to function arguments.");
429 
430 }  // namespace TF
431 }  // namespace mlir
432