1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
17
18 #include <memory>
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22
23 #include "absl/algorithm/container.h"
24 #include "absl/strings/str_split.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/c/eager/immediate_execution_context.h"
28 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
29 #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
30 #include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
31 #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
32 #include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
33 #include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
34 #include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
35 #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
36 #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
37 #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
38 #include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
39 #include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
40 #include "tensorflow/cc/saved_model/bundle_v2.h"
41 #include "tensorflow/cc/saved_model/constants.h"
42 #include "tensorflow/core/framework/attr_value.pb.h"
43 #include "tensorflow/core/framework/function.pb.h"
44 #include "tensorflow/core/framework/graph.pb.h"
45 #include "tensorflow/core/framework/node_def_util.h"
46 #include "tensorflow/core/framework/tensor.h"
47 #include "tensorflow/core/framework/tensor.pb.h"
48 #include "tensorflow/core/framework/tensor_shape.h"
49 #include "tensorflow/core/framework/types.pb.h"
50 #include "tensorflow/core/lib/gtl/flatmap.h"
51 #include "tensorflow/core/lib/hash/hash.h"
52 #include "tensorflow/core/platform/casts.h"
53 #include "tensorflow/core/platform/errors.h"
54 #include "tensorflow/core/platform/logging.h"
55 #include "tensorflow/core/platform/macros.h"
56 #include "tensorflow/core/platform/path.h"
57 #include "tensorflow/core/platform/stringpiece.h"
58 #include "tensorflow/core/platform/tstring.h"
59 #include "tensorflow/core/protobuf/meta_graph.pb.h"
60 #include "tensorflow/core/protobuf/saved_model.pb.h"
61 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
62 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
63
64 namespace tensorflow {
65
66 // Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
67 using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*,
68 StringPieceHasher>;
69
70 // Maps from a FunctionDef's name to the corresponding FlatTensorFunction
71 using FlatTensorFunctionMap =
72 gtl::FlatMap<std::string, std::unique_ptr<FlatTensorFunction>>;
73
74 namespace {
75
76 const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable(const TrackableObjectGraph::TrackableObject & trackable_object,absl::string_view name)77 FindSerializedTensorInTrackable(
78 const TrackableObjectGraph::TrackableObject& trackable_object,
79 absl::string_view name) {
80 for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
81 if (maybe_serialized_tensor.name() == name) {
82 return &maybe_serialized_tensor;
83 }
84 }
85 return nullptr;
86 }
87
88 // This function reads the Checkpoint embedded in the SavedModel, and calls the
89 // appropriate Restore ops on each of the variables.
90 // Note(bmzhao): Conceptually, objects that contain checkpointable state
91 // implement the "_gather_saveables_for_checkpoint" method
92 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/tracking/base.py#L953-L983
93 // which returns a dict of string key -> EITHER:
94 // 1. python callable (taking a checkpoint key) returning SaveableObject OR
95 // 2. variable (partitioned/resource/reference or otherwise)
96 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L58.
97 // The string key becomes the "name" attribute of the SerializedTensor proto
98 // in the TrackableObjectGraph,
99 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/core/protobuf/trackable_object_graph.proto#L26
100 // And the checkpoint_key is a globally unique string derived from this name:
101 // https://github.com/tensorflow/tensorflow/blob/842df9e6b516e42578a8d23b35d41176b9a6cf1d/tensorflow/python/training/tracking/graph_view.py#L236-L241
102 // SaveableObjects model the information needed to pass to the SaveV2/RestoreV2
103 // ops via their SaveSpec members
104 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L21,
105 // which contain the "real" checkpoint keys into the TensorBundle SSTable.
106 // They also contain the logic needed to take the restored tensors from
107 // RestoreV2 and load them back into the "object" they came from via their
108 // overridden "restore" method:
109 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85
RestoreCheckpoint(SavedModelV2Bundle * bundle,const RevivedObjects & revived_objects,const std::string & directory,ImmediateExecutionContext * context)110 Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
111 const RevivedObjects& revived_objects,
112 const std::string& directory,
113 ImmediateExecutionContext* context) {
114 // TODO(bmzhao): Batch up all the restores into a single restore op per
115 // device, following logic in MultiDeviceSaver.
116 TF_RETURN_IF_ERROR(bundle->VisitObjectsToRestore(
117 [&revived_objects, &directory, context, bundle](
118 int node, const TrackableObjectGraph::TrackableObject& trackable) {
119 if (bundle->saved_object_graph().nodes(node).kind_case() !=
120 SavedObject::kVariable) {
121 // TODO(bmzhao): This requires using the newly added Save/Restore
122 // functions from
123 // https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c
124 LOG(WARNING) << "Restoring non-variable objects has not been "
125 "implemented yet. (Kind="
126 << bundle->saved_object_graph().nodes(node).kind_case()
127 << ")";
128 return Status::OK();
129 }
130
131 Variable* variable = revived_objects.variables.at(node).get();
132
133 // Restore the tensor's value from the checkpoint
134 const TrackableObjectGraph::TrackableObject::SerializedTensor*
135 attribute =
136 FindSerializedTensorInTrackable(trackable, "VARIABLE_VALUE");
137 if (attribute == nullptr) {
138 return errors::FailedPrecondition(
139 "Could not find SerializedTensor with name VARIABLE_VALUE for "
140 "saved variable");
141 }
142
143 const std::string& checkpoint_key = attribute->checkpoint_key();
144 if (!bundle->variable_reader()->Contains(checkpoint_key)) {
145 LOG(WARNING) << "No checkpoint entry found for " << checkpoint_key
146 << ". Variable will be uninitialized.";
147 return Status();
148 }
149
150 std::string variables_path_prefix =
151 io::JoinPath(directory, kSavedModelVariablesDirectory,
152 kSavedModelVariablesFilename);
153 ImmediateTensorHandlePtr restored_output;
154 TF_RETURN_IF_ERROR(internal::SingleRestore(
155 context, variables_path_prefix, checkpoint_key, variable->dtype(),
156 &restored_output));
157
158 // Assign the restored tensor's value to the variable
159 return variable->Assign(restored_output.get());
160 }));
161
162 return Status();
163 }
164
InitializeAllResources(const RevivedObjects & revived)165 Status InitializeAllResources(const RevivedObjects& revived) {
166 for (const auto& node_and_resource : revived.restored_resources) {
167 const RestoredResource& resource = node_and_resource.second;
168 TF_RETURN_IF_ERROR(resource.Initialize());
169 }
170 return Status();
171 }
172
173 } // namespace
174
GetFunction(const std::string & function_path,ConcreteFunction ** function)175 Status TFSavedModelAPI::GetFunction(const std::string& function_path,
176 ConcreteFunction** function) {
177 absl::optional<int> node =
178 internal::FindNodeAtPath(function_path, bundle_.saved_object_graph());
179 if (!node.has_value()) {
180 return errors::NotFound("No saved object found at path ", function_path);
181 }
182
183 *function = revived_objects_.concrete_functions.Find(*node);
184 if (*function == nullptr) {
185 return errors::NotFound("No function found at path ", function_path);
186 }
187
188 return Status();
189 }
190
GetSignatureDefFunction(const std::string & signature_def_key,SignatureDefFunction ** function)191 Status TFSavedModelAPI::GetSignatureDefFunction(
192 const std::string& signature_def_key, SignatureDefFunction** function) {
193 auto signatures_iter =
194 revived_objects_.signatures_map.find(signature_def_key);
195 if (signatures_iter == revived_objects_.signatures_map.end()) {
196 return errors::NotFound("No signature with key ", signature_def_key,
197 " was found");
198 }
199 int node = signatures_iter->second;
200
201 auto function_iter = revived_objects_.signature_def_functions.find(node);
202 if (function_iter == revived_objects_.signature_def_functions.end()) {
203 return errors::Internal(
204 "Unable to find SignatureDefFunction associated with key ",
205 signature_def_key, " despite key being valid.");
206 }
207
208 *function = function_iter->second.get();
209 return Status();
210 }
211
GetVariable(const std::string & variable_path,Variable ** variable)212 Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
213 Variable** variable) {
214 absl::optional<int> node =
215 internal::FindNodeAtPath(variable_path, bundle_.saved_object_graph());
216 if (!node.has_value()) {
217 return errors::NotFound("No saved object found at path ", variable_path);
218 }
219
220 auto variables_iter = revived_objects_.variables.find(*node);
221 if (variables_iter == revived_objects_.variables.end()) {
222 return errors::NotFound("No variable found at path ", variable_path);
223 }
224
225 *variable = variables_iter->second.get();
226 return Status();
227 }
228
TFSavedModelAPI(const std::string & directory,SavedModelV2Bundle bundle,RevivedObjects revived_objects)229 TFSavedModelAPI::TFSavedModelAPI(const std::string& directory,
230 SavedModelV2Bundle bundle,
231 RevivedObjects revived_objects)
232 : directory_(directory),
233 bundle_(std::move(bundle)),
234 revived_objects_(std::move(revived_objects)) {}
235
Load(const std::string & directory,const absl::optional<std::unordered_set<std::string>> & tags,ImmediateExecutionContext * context,std::unique_ptr<TFSavedModelAPI> * out)236 Status TFSavedModelAPI::Load(
237 const std::string& directory,
238 const absl::optional<std::unordered_set<std::string>>& tags,
239 ImmediateExecutionContext* context, std::unique_ptr<TFSavedModelAPI>* out) {
240 // TODO(bmzhao): Add support for loading a TF1 SavedModel.
241 if (tags) {
242 return errors::Unimplemented(
243 "Loading saved models with explicit tags will be supported in the "
244 "future");
245 }
246
247 SavedModelV2Bundle bundle;
248 TF_RETURN_IF_ERROR(SavedModelV2Bundle::Load(directory, &bundle));
249
250 // TODO(bmzhao): Mangle loaded function names so that different
251 // models loaded in the same runtime Context don't clobber eachother.
252 // This occurs in python here:
253 // https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
254
255 // For each node in the graph, we should initialize an object of the
256 // corresponding type. For objects that depend on the initialization of other
257 // objects (like functions which capture resources), we will initialize them
258 // later.
259 PartiallyRevivedObjects partially_revived_objects;
260 TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects(
261 bundle.meta_graph_def(), context, directory, &partially_revived_objects));
262
263 RevivedObjects revived_objects;
264 TF_RETURN_IF_ERROR(partially_revived_objects.Build(
265 context, bundle.saved_object_graph(), &revived_objects));
266
267 // Revive function library functions as concrete functions without captures.
268 // This is necessary because object graph functions may refer to functions
269 // _not_ in the object graph: A while loop, for example, will create two
270 // auxiliary `while_cond` and `while_body` functions that are only present in
271 // the graph def function library.
272 for (const FunctionDef& function :
273 bundle.meta_graph_def().graph_def().library().function()) {
274 std::unique_ptr<TFConcreteFunction> concrete_function;
275 TF_RETURN_IF_ERROR(TFConcreteFunction::Create(/*function_def=*/&function,
276 /*captures=*/{},
277 /*metadata=*/{},
278 /*ctx=*/context,
279 /*out=*/&concrete_function));
280 revived_objects.concrete_functions.Insert(std::move(concrete_function));
281 }
282
283 TF_RETURN_IF_ERROR(
284 RestoreCheckpoint(&bundle, revived_objects, directory, context));
285
286 TF_RETURN_IF_ERROR(InitializeAllResources(revived_objects));
287
288 out->reset(new TFSavedModelAPI(directory, std::move(bundle),
289 std::move(revived_objects)));
290 return Status();
291 }
292
293 } // namespace tensorflow
294