1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
26 #include "mlir/IR/Types.h"  // from @llvm-project
27 #include "mlir/IR/Value.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
32 
33 namespace mlir {
34 namespace TF {
35 namespace {
36 
37 // Location attribute.
38 constexpr StringRef kClassAttr = "_class";
39 constexpr StringRef kLocationPrefix = "loc:@";
40 
41 // A pass that converts readonly reference variables to the corresponding
42 // resource variables.
43 //
44 // It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable).
45 //
46 // For the background, this pass is a part of hoisting VariableV2 ops by
47 // re-using the pipeline for hoisting (VarHandle -> ReadVariable) cases, which
48 //  can be done by the following passes:
49 //  - Capturing resource values into global tensors (importing saved model).
50 //  - Promoting VarHandle ops to function input/outputs.
51 //  - Freezing global tensor pass.
52 //
53 // This path assumes that all the VariableV2 ops is read-only via verifying the
54 // heuristic method that assumes that all the users of them is Identity op,
55 // fed directly.
56 class ConvertReadonlyReferenceVariablesToResourceVariablesPass
57     : public PassWrapper<
58           ConvertReadonlyReferenceVariablesToResourceVariablesPass,
59           FunctionPass> {
60  public:
61   void runOnFunction() override;
62 };
63 
64 // Parse node name from "_class" attribute.
GetNodeNameFromClassAttr(Operation * op)65 StringRef GetNodeNameFromClassAttr(Operation *op) {
66   ArrayAttr classes_attr = op->getAttrOfType<ArrayAttr>(kClassAttr);
67   if (!classes_attr) {
68     // Attampt to parse "_class" from the IdentityOp that follows VariableV2.
69     // For read-only reference variables, IdentityOp should be the only user of
70     // VariableV2.
71     auto identity_op = op->getUsers().begin();
72     classes_attr = identity_op->getAttrOfType<ArrayAttr>(kClassAttr);
73     if (!classes_attr) {
74       op->emitOpError() << "has no '_class' attribute";
75       return StringRef();
76     }
77   }
78 
79   StringRef result;
80   for (Attribute class_attr : classes_attr) {
81     StringRef node_name = class_attr.cast<StringAttr>().getValue();
82     if (!node_name.startswith(kLocationPrefix)) {
83       continue;
84     }
85     if (!result.empty()) {
86       // Invalid case since there are multiple loc:@ attributes.
87       op->emitOpError()
88           << "expects only one named location in '_class' attribute, but got "
89           << classes_attr;
90       return StringRef();
91     }
92     result = node_name.drop_front(kLocationPrefix.size());
93   }
94   if (result.empty()) {
95     op->emitOpError() << "expects variable name in '_class' attribute, but got "
96                       << classes_attr;
97   }
98   return result;
99 }
100 
runOnFunction()101 void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() {
102   FuncOp func = getFunction();
103 
104   OpBuilder builder(func.getContext());
105   SmallVector<VariableV2Op, 4> variable_v2s_to_replace;
106 
107   // Checks all the VariableV2 ops is read-only via verifying the heuristic
108   // method that assumes that all the users of them is Identity op, feeded
109   // directly.
110   auto read_only_vars_fn = [&variable_v2s_to_replace](
111                                VariableV2Op variable_v2_op) {
112     if (variable_v2_op.getResult().use_empty()) {
113       // Erase the op when there is no user.
114       variable_v2_op.erase();
115       return mlir::WalkResult::advance();
116     }
117     if (!all_of(variable_v2_op.getResult().getUsers(), [&variable_v2_op](
118                                                            Operation *user) {
119           if (!isa<IdentityOp>(user)) {
120             variable_v2_op.emitOpError()
121                 << "expects all users to be 'tf.Identity', but got user "
122                 << user->getName();
123             return false;
124           }
125           return true;
126         })) {
127       return mlir::WalkResult::interrupt();
128     }
129     variable_v2s_to_replace.push_back(variable_v2_op);
130     return mlir::WalkResult::advance();
131   };
132 
133   WalkResult walk_res = func.walk(read_only_vars_fn);
134   if (walk_res.wasInterrupted()) return signalPassFailure();
135 
136   for (VariableV2Op variable_v2_op : variable_v2s_to_replace) {
137     builder.setInsertionPoint(variable_v2_op);
138     ShapedType shaped_type =
139         variable_v2_op.getResult().getType().cast<ShapedType>();
140     TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
141     StringAttr device_attr =
142         variable_v2_op->getAttrOfType<StringAttr>("device");
143     if (!device_attr) device_attr = builder.getStringAttr("");
144     StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op);
145     if (variable_name.empty()) {
146       return signalPassFailure();
147     }
148     VarHandleOp var_handle_op = builder.create<VarHandleOp>(
149         variable_v2_op.getLoc(),
150         ArrayRef<Type>{RankedTensorType::get(
151             {}, TF::ResourceType::get(ArrayRef<TensorType>{tensor_type},
152                                       builder.getContext()))},
153         ArrayRef<Value>{},
154         ArrayRef<NamedAttribute>{
155             builder.getNamedAttr("device", device_attr),
156             builder.getNamedAttr("container", variable_v2_op.containerAttr()),
157             builder.getNamedAttr("shared_name",
158                                  builder.getStringAttr(variable_name))});
159     for (Operation *user :
160          make_early_inc_range(variable_v2_op.getResult().getUsers())) {
161       builder.setInsertionPoint(user);
162       ReadVariableOp read_variable_op = builder.create<ReadVariableOp>(
163           user->getLoc(), ArrayRef<Type>{tensor_type},
164           ArrayRef<Value>{var_handle_op});
165       user->getResult(0).replaceAllUsesWith(read_variable_op.getResult());
166       user->erase();
167     }
168     variable_v2_op.erase();
169   }
170 }
171 
172 }  // namespace
173 
174 std::unique_ptr<OperationPass<FuncOp>>
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass()175 CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() {
176   return std::make_unique<
177       ConvertReadonlyReferenceVariablesToResourceVariablesPass>();
178 }
179 
180 static PassRegistration<
181     ConvertReadonlyReferenceVariablesToResourceVariablesPass>
182     pass("tf-readonly-references-to-resources",
183          "Convert readonly reference variables to resource variables.");
184 
185 }  // namespace TF
186 
187 }  // namespace mlir
188