1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
17 
18 #include <memory>
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "absl/algorithm/container.h"
24 #include "absl/strings/str_split.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/c/eager/immediate_execution_context.h"
28 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
29 #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
30 #include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
31 #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
32 #include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h"
33 #include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
34 #include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"
35 #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
36 #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
37 #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
38 #include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
39 #include "tensorflow/c/experimental/saved_model/core/signature_def_function.h"
40 #include "tensorflow/cc/saved_model/bundle_v2.h"
41 #include "tensorflow/cc/saved_model/constants.h"
42 #include "tensorflow/core/framework/attr_value.pb.h"
43 #include "tensorflow/core/framework/function.pb.h"
44 #include "tensorflow/core/framework/graph.pb.h"
45 #include "tensorflow/core/framework/node_def_util.h"
46 #include "tensorflow/core/framework/tensor.h"
47 #include "tensorflow/core/framework/tensor.pb.h"
48 #include "tensorflow/core/framework/tensor_shape.h"
49 #include "tensorflow/core/framework/types.pb.h"
50 #include "tensorflow/core/lib/gtl/flatmap.h"
51 #include "tensorflow/core/lib/hash/hash.h"
52 #include "tensorflow/core/platform/casts.h"
53 #include "tensorflow/core/platform/errors.h"
54 #include "tensorflow/core/platform/logging.h"
55 #include "tensorflow/core/platform/macros.h"
56 #include "tensorflow/core/platform/path.h"
57 #include "tensorflow/core/platform/stringpiece.h"
58 #include "tensorflow/core/platform/tstring.h"
59 #include "tensorflow/core/protobuf/meta_graph.pb.h"
60 #include "tensorflow/core/protobuf/saved_model.pb.h"
61 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
62 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
63 
64 namespace tensorflow {
65 
// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary.
// NOTE(review): keys are StringPiece views and values are raw pointers, so
// presumably the FunctionDefLibrary must outlive the map -- confirm at the
// point of use.
using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*,
                                    StringPieceHasher>;

// Maps from a FunctionDef's name to the corresponding FlatTensorFunction.
// The map holds each FlatTensorFunction through a std::unique_ptr.
using FlatTensorFunctionMap =
    gtl::FlatMap<std::string, std::unique_ptr<FlatTensorFunction>>;
73 
74 namespace {
75 
76 const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable(const TrackableObjectGraph::TrackableObject & trackable_object,absl::string_view name)77 FindSerializedTensorInTrackable(
78     const TrackableObjectGraph::TrackableObject& trackable_object,
79     absl::string_view name) {
80   for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
81     if (maybe_serialized_tensor.name() == name) {
82       return &maybe_serialized_tensor;
83     }
84   }
85   return nullptr;
86 }
87 
88 // This function reads the Checkpoint embedded in the SavedModel, and calls the
89 // appropriate Restore ops on each of the variables.
90 // Note(bmzhao): Conceptually, objects that contain checkpointable state
91 // implement the "_gather_saveables_for_checkpoint" method
92 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/tracking/base.py#L953-L983
93 // which returns a dict of string key -> EITHER:
94 // 1. python callable (taking a checkpoint key) returning SaveableObject OR
95 // 2. variable (partitioned/resource/reference or otherwise)
96 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L58.
97 // The string key becomes the "name" attribute of the SerializedTensor proto
98 // in the TrackableObjectGraph,
99 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/core/protobuf/trackable_object_graph.proto#L26
100 // And the checkpoint_key is a globally unique string derived from this name:
101 // https://github.com/tensorflow/tensorflow/blob/842df9e6b516e42578a8d23b35d41176b9a6cf1d/tensorflow/python/training/tracking/graph_view.py#L236-L241
102 // SaveableObjects model the information needed to pass to the SaveV2/RestoreV2
103 // ops via their SaveSpec members
104 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L21,
105 // which contain the "real" checkpoint keys into the TensorBundle SSTable.
106 // They also contain the logic needed to take the restored tensors from
107 // RestoreV2 and load them back into the "object" they came from via their
108 // overridden "restore" method:
109 // https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85
RestoreCheckpoint(SavedModelV2Bundle * bundle,const RevivedObjects & revived_objects,const std::string & directory,ImmediateExecutionContext * context)110 Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
111                          const RevivedObjects& revived_objects,
112                          const std::string& directory,
113                          ImmediateExecutionContext* context) {
114   // TODO(bmzhao): Batch up all the restores into a single restore op per
115   // device, following logic in MultiDeviceSaver.
116   TF_RETURN_IF_ERROR(bundle->VisitObjectsToRestore(
117       [&revived_objects, &directory, context, bundle](
118           int node, const TrackableObjectGraph::TrackableObject& trackable) {
119         if (bundle->saved_object_graph().nodes(node).kind_case() !=
120             SavedObject::kVariable) {
121           // TODO(bmzhao): This requires using the newly added Save/Restore
122           // functions from
123           // https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c
124           LOG(WARNING) << "Restoring non-variable objects has not been "
125                           "implemented yet. (Kind="
126                        << bundle->saved_object_graph().nodes(node).kind_case()
127                        << ")";
128           return Status::OK();
129         }
130 
131         Variable* variable = revived_objects.variables.at(node).get();
132 
133         // Restore the tensor's value from the checkpoint
134         const TrackableObjectGraph::TrackableObject::SerializedTensor*
135             attribute =
136                 FindSerializedTensorInTrackable(trackable, "VARIABLE_VALUE");
137         if (attribute == nullptr) {
138           return errors::FailedPrecondition(
139               "Could not find SerializedTensor with name VARIABLE_VALUE for "
140               "saved variable");
141         }
142 
143         const std::string& checkpoint_key = attribute->checkpoint_key();
144         if (!bundle->variable_reader()->Contains(checkpoint_key)) {
145           LOG(WARNING) << "No checkpoint entry found for " << checkpoint_key
146                        << ". Variable will be uninitialized.";
147           return Status();
148         }
149 
150         std::string variables_path_prefix =
151             io::JoinPath(directory, kSavedModelVariablesDirectory,
152                          kSavedModelVariablesFilename);
153         ImmediateTensorHandlePtr restored_output;
154         TF_RETURN_IF_ERROR(internal::SingleRestore(
155             context, variables_path_prefix, checkpoint_key, variable->dtype(),
156             &restored_output));
157 
158         // Assign the restored tensor's value to the variable
159         return variable->Assign(restored_output.get());
160       }));
161 
162   return Status();
163 }
164 
InitializeAllResources(const RevivedObjects & revived)165 Status InitializeAllResources(const RevivedObjects& revived) {
166   for (const auto& node_and_resource : revived.restored_resources) {
167     const RestoredResource& resource = node_and_resource.second;
168     TF_RETURN_IF_ERROR(resource.Initialize());
169   }
170   return Status();
171 }
172 
173 }  // namespace
174 
GetFunction(const std::string & function_path,ConcreteFunction ** function)175 Status TFSavedModelAPI::GetFunction(const std::string& function_path,
176                                     ConcreteFunction** function) {
177   absl::optional<int> node =
178       internal::FindNodeAtPath(function_path, bundle_.saved_object_graph());
179   if (!node.has_value()) {
180     return errors::NotFound("No saved object found at path ", function_path);
181   }
182 
183   *function = revived_objects_.concrete_functions.Find(*node);
184   if (*function == nullptr) {
185     return errors::NotFound("No function found at path ", function_path);
186   }
187 
188   return Status();
189 }
190 
GetSignatureDefFunction(const std::string & signature_def_key,SignatureDefFunction ** function)191 Status TFSavedModelAPI::GetSignatureDefFunction(
192     const std::string& signature_def_key, SignatureDefFunction** function) {
193   auto signatures_iter =
194       revived_objects_.signatures_map.find(signature_def_key);
195   if (signatures_iter == revived_objects_.signatures_map.end()) {
196     return errors::NotFound("No signature with key ", signature_def_key,
197                             " was found");
198   }
199   int node = signatures_iter->second;
200 
201   auto function_iter = revived_objects_.signature_def_functions.find(node);
202   if (function_iter == revived_objects_.signature_def_functions.end()) {
203     return errors::Internal(
204         "Unable to find SignatureDefFunction associated with key ",
205         signature_def_key, " despite key being valid.");
206   }
207 
208   *function = function_iter->second.get();
209   return Status();
210 }
211 
GetVariable(const std::string & variable_path,Variable ** variable)212 Status TFSavedModelAPI::GetVariable(const std::string& variable_path,
213                                     Variable** variable) {
214   absl::optional<int> node =
215       internal::FindNodeAtPath(variable_path, bundle_.saved_object_graph());
216   if (!node.has_value()) {
217     return errors::NotFound("No saved object found at path ", variable_path);
218   }
219 
220   auto variables_iter = revived_objects_.variables.find(*node);
221   if (variables_iter == revived_objects_.variables.end()) {
222     return errors::NotFound("No variable found at path ", variable_path);
223   }
224 
225   *variable = variables_iter->second.get();
226   return Status();
227 }
228 
// Builds the API object from state that has already been loaded.
// `directory` is copied into directory_; `bundle` and `revived_objects` are
// taken by value and moved into bundle_ and revived_objects_ respectively.
// The constructor itself does no loading or validation.
TFSavedModelAPI::TFSavedModelAPI(const std::string& directory,
                                 SavedModelV2Bundle bundle,
                                 RevivedObjects revived_objects)
    : directory_(directory),
      bundle_(std::move(bundle)),
      revived_objects_(std::move(revived_objects)) {}
235 
Load(const std::string & directory,const absl::optional<std::unordered_set<std::string>> & tags,ImmediateExecutionContext * context,std::unique_ptr<TFSavedModelAPI> * out)236 Status TFSavedModelAPI::Load(
237     const std::string& directory,
238     const absl::optional<std::unordered_set<std::string>>& tags,
239     ImmediateExecutionContext* context, std::unique_ptr<TFSavedModelAPI>* out) {
240   // TODO(bmzhao): Add support for loading a TF1 SavedModel.
241   if (tags) {
242     return errors::Unimplemented(
243         "Loading saved models with explicit tags will be supported in the "
244         "future");
245   }
246 
247   SavedModelV2Bundle bundle;
248   TF_RETURN_IF_ERROR(SavedModelV2Bundle::Load(directory, &bundle));
249 
250   // TODO(bmzhao): Mangle loaded function names so that different
251   // models loaded in the same runtime Context don't clobber eachother.
252   // This occurs in python here:
253   // https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
254 
255   // For each node in the graph, we should initialize an object of the
256   // corresponding type. For objects that depend on the initialization of other
257   // objects (like functions which capture resources), we will initialize them
258   // later.
259   PartiallyRevivedObjects partially_revived_objects;
260   TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects(
261       bundle.meta_graph_def(), context, directory, &partially_revived_objects));
262 
263   RevivedObjects revived_objects;
264   TF_RETURN_IF_ERROR(partially_revived_objects.Build(
265       context, bundle.saved_object_graph(), &revived_objects));
266 
267   // Revive function library functions as concrete functions without captures.
268   // This is necessary because object graph functions may refer to functions
269   // _not_ in the object graph: A while loop, for example, will create two
270   // auxiliary `while_cond` and `while_body` functions that are only present in
271   // the graph def function library.
272   for (const FunctionDef& function :
273        bundle.meta_graph_def().graph_def().library().function()) {
274     std::unique_ptr<TFConcreteFunction> concrete_function;
275     TF_RETURN_IF_ERROR(TFConcreteFunction::Create(/*function_def=*/&function,
276                                                   /*captures=*/{},
277                                                   /*metadata=*/{},
278                                                   /*ctx=*/context,
279                                                   /*out=*/&concrete_function));
280     revived_objects.concrete_functions.Insert(std::move(concrete_function));
281   }
282 
283   TF_RETURN_IF_ERROR(
284       RestoreCheckpoint(&bundle, revived_objects, directory, context));
285 
286   TF_RETURN_IF_ERROR(InitializeAllResources(revived_objects));
287 
288   out->reset(new TFSavedModelAPI(directory, std::move(bundle),
289                                  std::move(revived_objects)));
290   return Status();
291 }
292 
293 }  // namespace tensorflow
294