1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
26 #include "mlir/IR/Types.h" // from @llvm-project
27 #include "mlir/IR/Value.h" // from @llvm-project
28 #include "mlir/Pass/Pass.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
32
33 namespace mlir {
34 namespace TF {
35 namespace {
36
37 // Location attribute.
38 constexpr StringRef kClassAttr = "_class";
39 constexpr StringRef kLocationPrefix = "loc:@";
40
41 // A pass that converts readonly reference variables to the corresponding
42 // resource variables.
43 //
44 // It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable).
45 //
46 // For the background, this pass is a part of hoisting VariableV2 ops by
47 // re-using the pipeline for hoisting (VarHandle -> ReadVariable) cases, which
48 // can be done by the following passes:
49 // - Capturing resource values into global tensors (importing saved model).
50 // - Promoting VarHandle ops to function input/outputs.
51 // - Freezing global tensor pass.
52 //
53 // This path assumes that all the VariableV2 ops is read-only via verifying the
54 // heuristic method that assumes that all the users of them is Identity op,
55 // fed directly.
56 class ConvertReadonlyReferenceVariablesToResourceVariablesPass
57 : public PassWrapper<
58 ConvertReadonlyReferenceVariablesToResourceVariablesPass,
59 FunctionPass> {
60 public:
61 void runOnFunction() override;
62 };
63
64 // Parse node name from "_class" attribute.
GetNodeNameFromClassAttr(Operation * op)65 StringRef GetNodeNameFromClassAttr(Operation *op) {
66 ArrayAttr classes_attr = op->getAttrOfType<ArrayAttr>(kClassAttr);
67 if (!classes_attr) {
68 // Attampt to parse "_class" from the IdentityOp that follows VariableV2.
69 // For read-only reference variables, IdentityOp should be the only user of
70 // VariableV2.
71 auto identity_op = op->getUsers().begin();
72 classes_attr = identity_op->getAttrOfType<ArrayAttr>(kClassAttr);
73 if (!classes_attr) {
74 op->emitOpError() << "has no '_class' attribute";
75 return StringRef();
76 }
77 }
78
79 StringRef result;
80 for (Attribute class_attr : classes_attr) {
81 StringRef node_name = class_attr.cast<StringAttr>().getValue();
82 if (!node_name.startswith(kLocationPrefix)) {
83 continue;
84 }
85 if (!result.empty()) {
86 // Invalid case since there are multiple loc:@ attributes.
87 op->emitOpError()
88 << "expects only one named location in '_class' attribute, but got "
89 << classes_attr;
90 return StringRef();
91 }
92 result = node_name.drop_front(kLocationPrefix.size());
93 }
94 if (result.empty()) {
95 op->emitOpError() << "expects variable name in '_class' attribute, but got "
96 << classes_attr;
97 }
98 return result;
99 }
100
runOnFunction()101 void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() {
102 FuncOp func = getFunction();
103
104 OpBuilder builder(func.getContext());
105 SmallVector<VariableV2Op, 4> variable_v2s_to_replace;
106
107 // Checks all the VariableV2 ops is read-only via verifying the heuristic
108 // method that assumes that all the users of them is Identity op, feeded
109 // directly.
110 auto read_only_vars_fn = [&variable_v2s_to_replace](
111 VariableV2Op variable_v2_op) {
112 if (variable_v2_op.getResult().use_empty()) {
113 // Erase the op when there is no user.
114 variable_v2_op.erase();
115 return mlir::WalkResult::advance();
116 }
117 if (!all_of(variable_v2_op.getResult().getUsers(), [&variable_v2_op](
118 Operation *user) {
119 if (!isa<IdentityOp>(user)) {
120 variable_v2_op.emitOpError()
121 << "expects all users to be 'tf.Identity', but got user "
122 << user->getName();
123 return false;
124 }
125 return true;
126 })) {
127 return mlir::WalkResult::interrupt();
128 }
129 variable_v2s_to_replace.push_back(variable_v2_op);
130 return mlir::WalkResult::advance();
131 };
132
133 WalkResult walk_res = func.walk(read_only_vars_fn);
134 if (walk_res.wasInterrupted()) return signalPassFailure();
135
136 for (VariableV2Op variable_v2_op : variable_v2s_to_replace) {
137 builder.setInsertionPoint(variable_v2_op);
138 ShapedType shaped_type =
139 variable_v2_op.getResult().getType().cast<ShapedType>();
140 TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
141 StringAttr device_attr =
142 variable_v2_op->getAttrOfType<StringAttr>("device");
143 if (!device_attr) device_attr = builder.getStringAttr("");
144 StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op);
145 if (variable_name.empty()) {
146 return signalPassFailure();
147 }
148 VarHandleOp var_handle_op = builder.create<VarHandleOp>(
149 variable_v2_op.getLoc(),
150 ArrayRef<Type>{RankedTensorType::get(
151 {}, TF::ResourceType::get(ArrayRef<TensorType>{tensor_type},
152 builder.getContext()))},
153 ArrayRef<Value>{},
154 ArrayRef<NamedAttribute>{
155 builder.getNamedAttr("device", device_attr),
156 builder.getNamedAttr("container", variable_v2_op.containerAttr()),
157 builder.getNamedAttr("shared_name",
158 builder.getStringAttr(variable_name))});
159 for (Operation *user :
160 make_early_inc_range(variable_v2_op.getResult().getUsers())) {
161 builder.setInsertionPoint(user);
162 ReadVariableOp read_variable_op = builder.create<ReadVariableOp>(
163 user->getLoc(), ArrayRef<Type>{tensor_type},
164 ArrayRef<Value>{var_handle_op});
165 user->getResult(0).replaceAllUsesWith(read_variable_op.getResult());
166 user->erase();
167 }
168 variable_v2_op.erase();
169 }
170 }
171
172 } // namespace
173
174 std::unique_ptr<OperationPass<FuncOp>>
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass()175 CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() {
176 return std::make_unique<
177 ConvertReadonlyReferenceVariablesToResourceVariablesPass>();
178 }
179
180 static PassRegistration<
181 ConvertReadonlyReferenceVariablesToResourceVariablesPass>
182 pass("tf-readonly-references-to-resources",
183 "Convert readonly reference variables to resource variables.");
184
185 } // namespace TF
186
187 } // namespace mlir
188