/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This pass promotes resource accesses in the main function to input arguments
// and outputs of the main function.
//
// Two types of resources are supported:
// (1) A function argument of TF::ResourceType type.
// (2) A VarHandleOp in the function.
//
// After the pass,
//
//  . The function will have an input argument for each resource that is
//    already provided as an input argument or is read. The type of the input
//    argument will become the type of the tensor value held by the resource.
//
//  . The function will have an output for each resource that is written. The
//    type of the output will become the type of the tensor value held by the
//    resource.
//
// Variable identification and input-output aliasing information is recorded
// as named attributes of the input arguments and outputs:
//
//  . 'tf.resource_name' matches 'shared_name' of VarHandleOp, which represents
//    the identifier of the corresponding resource. This attribute is added to
//    an input argument if the initial value of the resource is read, or to the
//    output if the initial value is not read.
//
//  . 'tf.aliasing_output' is the index of the function output that is an alias
//    of the input argument. This attribute is added only to the input argument
//    when the initial value of the corresponding resource is read and the
//    resource is written later. An illustrative example follows the
//    assumptions below.
//
// Assumptions of this pass:
//  . Compound resource operations have already been decomposed.
//  . Dead functions have already been removed, as resource arguments in dead
//    functions can cause the pass to fail.
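//
// For illustration only (a hypothetical input, not taken from the test
// suite), a main function that reads and then writes a single variable with
// shared name "x":
//
//   func @main() -> tensor<f32> {
//     %0 = "tf.VarHandleOp"() {container = "", shared_name = "x"}
//            : () -> tensor<!tf.resource<tensor<f32>>>
//     %1 = "tf.ReadVariableOp"(%0)
//            : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
//     %2 = "tf.AddV2"(%1, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
//     "tf.AssignVariableOp"(%0, %2)
//            : (tensor<!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
//     return %1 : tensor<f32>
//   }
//
// would be rewritten to roughly the following, where the promoted argument
// carries both attributes because the resource is read before being written:
//
//   func @main(%arg0: tensor<f32> {tf.aliasing_output = 1 : i64,
//                                  tf.resource_name = "x"})
//       -> (tensor<f32>, tensor<f32>) {
//     %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
//     return %arg0, %0 : tensor<f32>, tensor<f32>
//   }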

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

namespace mlir {
namespace TF {
namespace {

constexpr char kResourceFunctionMsg[] =
    "expects function level resource argument";
constexpr char kInvalidResourceMsg[] =
    "expects resource to be a VarHandleOp or function argument";
constexpr char kResourceNameArgAttr[] = "tf.resource_name";

// Checks if a function has only one block.
mlir::LogicalResult CheckSingleBlockFunction(FuncOp function) {
  if (!llvm::hasSingleElement(function)) {
    return function.emitError()
           << "expects function '" << function.getName()
           << "' to have 1 block, got " << function.getBlocks().size();
  }
  return success();
}

// Collects names of users of a resource that are not `tf.ReadVariableOp` and
// not `tf.AssignVariableOp`.
llvm::SmallSet<llvm::StringRef, 1> GetCompositeResourceUserNames(
    Value resource) {
  // SmallSet will use a vector when there is only one element and a std::set
  // when there is more than one element. This ensures that the operations in
  // the error message are ordered.
  llvm::SmallSet<llvm::StringRef, 1> composite_users;
  for (Operation* user : resource.getUsers())
    if (!llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp>(user))
      composite_users.insert(user->getName().getStringRef());

  return composite_users;
}

// Checks that the only users of `tf.VarHandleOp` are
// `tf.ReadVariableOp` and `tf.AssignVariableOp`.
mlir::LogicalResult ValidateVarHandle(TF::VarHandleOp var_handle_op) {
  auto composite_ops = GetCompositeResourceUserNames(var_handle_op);
  if (!composite_ops.empty())
    return var_handle_op.emitOpError()
           << "expects users to be 'tf.ReadVariableOp' or "
              "'tf.AssignVariableOp', got ["
           << llvm::join(composite_ops.begin(), composite_ops.end(), ", ")
           << "]";

  return success();
}

// Checks that a resource argument has a valid resource subtype and that its
// users are only `tf.ReadVariableOp` and `tf.AssignVariableOp`.
mlir::LogicalResult ValidateResourceArgument(FuncOp function,
                                             BlockArgument resource_arg,
                                             TF::ResourceType resource_type) {
  if (resource_type.getSubtypes().size() != 1)
    return function.emitError()
           << "expects resource type of argument "
           << resource_arg.getArgNumber() << " to have one subtype, got "
           << resource_type;

  auto composite_ops = GetCompositeResourceUserNames(resource_arg);
  if (!composite_ops.empty())
    return function.emitError()
           << "expects users of resource argument "
           << resource_arg.getArgNumber()
           << " to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp', got ["
           << llvm::join(composite_ops.begin(), composite_ops.end(), ", ")
           << "]";

  return success();
}

// Adds resource arguments for every unique (by name) variable handle.
// Associated `tf.VarHandleOp`s are removed from the function. Variable shared
// names are returned in `var_handle_shared_names` based on the ordering of
// added resource arguments.
mlir::LogicalResult PromoteVarHandlesToArguments(
    FuncOp function, bool add_validation,
    llvm::SmallVectorImpl<std::string>* var_handle_shared_names) {
  Block& block = function.front();
  auto func_type = function.getType();

  auto func_arg_types = llvm::to_vector<4>(func_type.getInputs());
  llvm::SmallDenseMap<llvm::StringRef, int> var_arg_index_by_name;
  for (auto var_handle_op :
       llvm::make_early_inc_range(block.getOps<TF::VarHandleOp>())) {
    if (add_validation && failed(ValidateVarHandle(var_handle_op)))
      return failure();

    llvm::StringRef name = var_handle_op.shared_nameAttr().getValue();
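    // The first occurrence of a shared name adds a new resource argument;
    // later `tf.VarHandleOp`s with the same shared name reuse that argument.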
    auto it = var_arg_index_by_name.insert({name, func_arg_types.size()});
    if (it.second) {
      var_handle_shared_names->emplace_back(name);
      auto resource_type = var_handle_op.resource().getType();
      func_arg_types.push_back(resource_type);
      var_handle_op.resource().replaceAllUsesWith(
          block.addArgument(resource_type));
    } else {
      var_handle_op.resource().replaceAllUsesWith(
          block.getArgument(it.first->getSecond()));
    }
    var_handle_op.erase();
  }

  if (!var_handle_shared_names->empty())
    function.setType(FunctionType::get(function.getContext(), func_arg_types,
                                       func_type.getResults()));

  return success();
}

// Records the current live value for a resource variable and whether a read or
// write on the variable occurred.
struct ResourceInfo {
  Value live_value = nullptr;
  bool read = false;
  bool write = false;
};

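// Promotes resource accesses in `function` to explicit arguments and return
// values. Resource arguments (including the ones added for promoted
// `tf.VarHandleOp`s, whose shared names are given in
// `var_handle_shared_names`) are rewritten to their underlying tensor types,
// reads are replaced with the current live value, written values are returned
// as extra results, and the 'tf.resource_name'/'tf.aliasing_output' attributes
// are attached as described in the file header.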
LogicalResult PromoteResourcesToArguments(
    FuncOp function, llvm::ArrayRef<std::string> var_handle_shared_names) {
  Block& block = function.front();

  auto return_op = llvm::dyn_cast_or_null<ReturnOp>(block.getTerminator());
  if (!return_op)
    return function.emitError() << "expects function '" << function.getName()
                                << "' to have a MLIR ReturnOp";

  llvm::SmallVector<ResourceInfo, 4> resources(function.getNumArguments());
  auto argument_types = llvm::to_vector<4>(function.getType().getInputs());
  bool has_resources = false;
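  // Rewrites a resource argument in place to the tensor type held by the
  // resource and records the argument itself as the resource's initial live
  // value.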
  auto add_resource_argument = [&](BlockArgument arg,
                                   TF::ResourceType resource_type) {
    Type arg_type = resource_type.getSubtypes().front();
    arg.setType(arg_type);
    resources[arg.getArgNumber()].live_value = arg;
    argument_types[arg.getArgNumber()] = arg_type;
    has_resources = true;
  };

  // Loop through the non `tf.VarHandleOp` resource arguments in the function,
  // validate their uses and subtypes, and store a mapping from each argument
  // to itself as the current live value.
  auto func_args = function.getArguments().take_front(
      function.getNumArguments() - var_handle_shared_names.size());
  for (BlockArgument& func_arg : func_args) {
    auto resource_type =
        getElementTypeOrSelf(func_arg.getType()).dyn_cast<TF::ResourceType>();
    if (!resource_type) continue;
    if (failed(ValidateResourceArgument(function, func_arg, resource_type)))
      return failure();

    add_resource_argument(func_arg, resource_type);
  }

  // Loop through `tf.VarHandleOp` resource arguments in the function and store
  // a mapping from each argument to itself as the current live value. No
  // validations are necessary here as these arguments were validated prior to
  // being added.
  auto var_handle_args =
      function.getArguments().take_back(var_handle_shared_names.size());
  for (BlockArgument& var_handle_arg : var_handle_args) {
    auto resource_type =
        getElementTypeOrSelf(var_handle_arg.getType()).cast<TF::ResourceType>();
    add_resource_argument(var_handle_arg, resource_type);
  }

  if (!has_resources) return success();

  // We initially assign the argument for a resource as the live value for the
  // resource. We then walk through the operations in the function in their
  // lexical order, to update the live value for the resource when we see a
  // store to the resource and replace reads of the resource with uses of its
  // live value.
  for (Operation& op : llvm::make_early_inc_range(block)) {
    if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(&op)) {
      if (auto func_arg = read_op.resource().dyn_cast<BlockArgument>()) {
        if (func_arg.getOwner() != &block)
          return read_op.emitOpError(kResourceFunctionMsg);

        ResourceInfo& resource_info = resources[func_arg.getArgNumber()];
        resource_info.read = true;
        read_op.value().replaceAllUsesWith(resource_info.live_value);
      } else {
        return read_op.emitOpError(kInvalidResourceMsg);
      }

      read_op.erase();
    } else if (auto write_op = llvm::dyn_cast<TF::AssignVariableOp>(&op)) {
      if (auto func_arg = write_op.resource().dyn_cast<BlockArgument>()) {
        if (func_arg.getOwner() != &block)
          return write_op.emitOpError(kResourceFunctionMsg);

        ResourceInfo& resource_info = resources[func_arg.getArgNumber()];
        resource_info.write = true;
        resource_info.live_value = write_op.value();
      } else {
        return write_op.emitOpError(kInvalidResourceMsg);
      }

      write_op.erase();
    }
  }

  const int64_t num_results_before = function.getNumResults();
  auto return_operands = llvm::to_vector<4>(return_op.getOperands());
  auto result_types = llvm::to_vector<4>(return_op.getOperandTypes());
  llvm::SmallVector<std::pair<int64_t, llvm::StringRef>, 4>
      output_only_resources;
  llvm::SmallVector<std::pair<int64_t, int64_t>, 4> input_output_alias;

  // Collect new return values for variable writes and either (a) output-only
  // resource attributes (if the resource is not promoted to an argument) or
  // (b) a mapping from resource input index to output alias (if the resource
  // has been promoted to an argument). Resource arguments that were originally
  // `tf.VarHandleOp`s but are never read are collected and then removed.
  OpBuilder builder(return_op);
  const int var_handles_start_idx =
      function.getNumArguments() - var_handle_shared_names.size();
  int new_argument_index = 0;
  llvm::SmallVector<int, 4> argument_indices_to_remove;
  for (auto resource_and_index : llvm::enumerate(resources)) {
    const auto& resource = resource_and_index.value();
    if (!resource.live_value) {
      // Ignore non-resource arguments.
      ++new_argument_index;
      continue;
    }

    const int64_t index = resource_and_index.index();
    const bool is_var_handle = index >= var_handles_start_idx;
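    // A written resource contributes a new return value. If the resource is
    // also exposed as an input (a non-VarHandleOp argument, or a VarHandleOp
    // whose initial value is read), record an input/output alias; otherwise
    // the output is identified only by the variable's shared name.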
    if (resource.write) {
      if (!is_var_handle || resource.read) {
        input_output_alias.push_back(
            {new_argument_index, return_operands.size()});
      } else if (is_var_handle) {
        output_only_resources.push_back(
            {return_operands.size(),
             var_handle_shared_names[index - var_handles_start_idx]});
      }
      return_operands.push_back(resource.live_value);
      result_types.push_back(resource.live_value.getType());
    }

    if (is_var_handle && !resource.read) {
      assert(block.getArgument(index).getUses().empty());
      argument_indices_to_remove.push_back(index);
    } else {
      if (is_var_handle) {
        // Add the resource_name attribute to the argument of a read
        // `tf.VarHandleOp`.
        function.setArgAttr(
            new_argument_index, kResourceNameArgAttr,
            builder.getStringAttr(
                var_handle_shared_names[index - var_handles_start_idx]));
      }
      ++new_argument_index;
    }
  }

  // Remove unread var handle arguments.
  for (int argument_index_to_remove :
       llvm::reverse(argument_indices_to_remove)) {
    block.eraseArgument(argument_index_to_remove);
    argument_types.erase(argument_types.begin() + argument_index_to_remove);
  }

  // Rewrite the return op if there are new return values for variable writes.
  const int return_operands_size = return_operands.size();
  if (return_operands_size > num_results_before) {
    builder.create<ReturnOp>(return_op.getLoc(), return_operands);
    return_op.erase();
  }

  // Update function argument and result types with new resource subtypes.
  function.setType(builder.getFunctionType(argument_types, result_types));

  // Add the resource_name attribute to the outputs of output-only resources.
  for (auto& resource : output_only_resources)
    function.setResultAttr(resource.first, kResourceNameArgAttr,
                           builder.getStringAttr(resource.second));

  // Add the aliasing_output attribute to the input arguments of resources that
  // are updated by the function.
  for (auto& input_output : input_output_alias)
    function.setArgAttr(input_output.first, "tf.aliasing_output",
                        builder.getI64IntegerAttr(input_output.second));

  return success();
}

class PromoteResourcesToArgsPass
    : public PassWrapper<PromoteResourcesToArgsPass, OperationPass<ModuleOp>> {
 public:
  void runOnOperation() override;
};

void PromoteResourcesToArgsPass::runOnOperation() {
  ModuleOp module = getOperation();
  FuncOp main_func = module.lookupSymbol<FuncOp>("main");
  if (!main_func) return;

  // This routine should only be called when control flow operations are still
  // represented with TF IfOp and WhileOp operations. In this case, there
  // should be only one basic block in the MLIR representation.
  if (failed(CheckSingleBlockFunction(main_func))) return signalPassFailure();

  llvm::SmallVector<std::string, 4> var_handle_shared_names;
  if (failed(ResourceLiftingForFunctionalControlFlow(main_func)) ||
      failed(PromoteVarHandlesToArguments(main_func, /*add_validation=*/true,
                                          &var_handle_shared_names)) ||
      failed(PromoteResourcesToArguments(main_func, var_handle_shared_names)))
    return signalPassFailure();
}

class PromoteVarHandlesToArgsPass
    : public PassWrapper<PromoteVarHandlesToArgsPass, OperationPass<ModuleOp>> {
 public:
  void runOnOperation() override;
};

void PromoteVarHandlesToArgsPass::runOnOperation() {
  ModuleOp module = getOperation();
  MLIRContext* context = module.getContext();
  for (auto function : module.getOps<FuncOp>()) {
    if (failed(CheckSingleBlockFunction(function))) return signalPassFailure();

    llvm::SmallVector<std::string, 4> var_handle_shared_names;
    (void)PromoteVarHandlesToArguments(function, /*add_validation=*/false,
                                       &var_handle_shared_names);

    // Add resource names for the `tf.VarHandleOp`s that were promoted to
    // resource arguments.
    const int var_handle_args_offset =
        function.getNumArguments() - var_handle_shared_names.size();
    for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names))
      function.setArgAttr(var_name_and_index.index() + var_handle_args_offset,
                          kResourceNameArgAttr,
                          StringAttr::get(context, var_name_and_index.value()));
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass() {
  return std::make_unique<PromoteResourcesToArgsPass>();
}

std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass() {
  return std::make_unique<PromoteVarHandlesToArgsPass>();
}

static PassRegistration<PromoteResourcesToArgsPass> pass(
    "tf-promote-resources-to-args",
    "Promote resource reads/writes to function inputs/outputs.");

static PassRegistration<PromoteVarHandlesToArgsPass> var_handle_pass(
    "tf-promote-var-handles-to-args",
    "Promote tf.VarHandleOps to function arguments.");

}  // namespace TF
}  // namespace mlir