1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <unordered_map>
21 #include <unordered_set>
22 
23 #include "absl/strings/str_split.h"
24 #include "absl/types/optional.h"
25 #include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
26 #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
27 #include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h"
28 #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
29 #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
30 #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
31 #include "tensorflow/c/tf_tensor_internal.h"
32 #include "tensorflow/cc/saved_model/loader_util.h"
33 #include "tensorflow/core/framework/node_def_util.h"
34 #include "tensorflow/core/framework/tensor.pb.h"
35 #include "tensorflow/core/lib/gtl/flatmap.h"
36 #include "tensorflow/core/lib/hash/hash.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/protobuf.h"
39 #include "tensorflow/core/platform/stringpiece.h"
40 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
41 #include "tensorflow/core/protobuf/struct.pb.h"
42 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
43 
44 namespace tensorflow {
45 namespace internal {
46 namespace {
47 
48 using StructuredValueDictEntry =
49     protobuf::MapPair<std::string, StructuredValue>;
50 
51 // Maps from a Nodedef's name to its corresponding AttrValues, for a given
52 // Graphdef
53 using NodeAttrMap =
54     gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher>;
55 
56 // Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
57 using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*,
58                                     StringPieceHasher>;
59 
60 // Looks up a SavedConstant's associated tensorproto from the NodeAttrMap and
61 // returns a tensorflow::Constant.
ConstantFromSavedConstant(ImmediateExecutionContext * ctx,const tensorflow::SavedConstant & saved_constant,const NodeAttrMap & node_attr_map,std::unique_ptr<Constant> * output)62 Status ConstantFromSavedConstant(
63     ImmediateExecutionContext* ctx,
64     const tensorflow::SavedConstant& saved_constant,
65     const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) {
66   const std::string& const_op_name = saved_constant.operation();
67   const auto& node_name_and_attrs = node_attr_map.find(const_op_name);
68   if (node_name_and_attrs == node_attr_map.end()) {
69     return errors::FailedPrecondition(
70         "Unable to find Const operation with name'", const_op_name,
71         "' in SavedModel graphdef");
72   }
73   const AttrValueMap* attrs = node_name_and_attrs->second;
74   const auto& attr_name_and_value = attrs->find("value");
75   if (attr_name_and_value == attrs->end()) {
76     return errors::FailedPrecondition("Unable to find Const operation '",
77                                       const_op_name, "'s value attribute");
78   }
79   const TensorProto& tensor_proto = attr_name_and_value->second.tensor();
80   return internal::TensorProtoToConstant(ctx, tensor_proto, output);
81 }
82 
83 // Finds the "signatures" object in the object graph, and fills a mapping of
84 // each signature's name to the corresponding function's node in the object
85 // graph.
GetSignaturesMap(const SavedObjectGraph & saved_objects,gtl::FlatMap<std::string,int> * signatures_map)86 Status GetSignaturesMap(const SavedObjectGraph& saved_objects,
87                         gtl::FlatMap<std::string, int>* signatures_map) {
88   if (saved_objects.nodes().empty()) {
89     return errors::FailedPrecondition("Saved Object Graph was empty.");
90   }
91   const SavedObject& root = saved_objects.nodes(0);
92   const SavedObject* signatures = nullptr;
93   for (const auto& child : root.children()) {
94     if (child.local_name() == "signatures") {
95       if (child.node_id() >= saved_objects.nodes().size()) {
96         return errors::FailedPrecondition(
97             "Signature object had child node id ", child.node_id(),
98             " which exceeds the size of the set of nodes");
99       }
100       signatures = &saved_objects.nodes(child.node_id());
101     }
102   }
103 
104   // Some basic sanity checks that this object is actually our "signatures" map
105   if (signatures == nullptr) {
106     // This is where the "signatures" attribute is always set:
107     // https://github.com/tensorflow/tensorflow/blob/a2c542a0d83227568f9214a2af9a38ae3625976f/tensorflow/python/saved_model/save.py#L1106-L1109
108     return errors::FailedPrecondition(
109         "SavedObjectGraph's root object must have a child 'signatures' object");
110   }
111   if (signatures->kind_case() != SavedObject::kUserObject) {
112     return errors::FailedPrecondition(
113         "Signatures must be a SavedObject of type UserObject.");
114   }
115   if (signatures->user_object().identifier() != "signature_map") {
116     // This is where the string comes from:
117     // https://github.com/tensorflow/tensorflow/blob/c59af2913aaec235d883f50428efef1086f4c0e6/tensorflow/python/saved_model/signature_serialization.py#L220
118     return errors::FailedPrecondition(
119         "Signatures SavedObject must have identifier 'signature_map'.");
120   }
121 
122   for (const auto& child : signatures->children()) {
123     (*signatures_map)[child.local_name()] = child.node_id();
124   }
125   return Status();
126 }
127 
128 // Perform some basic sanity checks on SavedConcreteFunction's input and
129 // output signatures with respect to the corresponding FunctionDef's input
130 // and output args.
ValidateSavedFunctionCompatibleWithFunctionDef(const SavedConcreteFunction & saved_concrete_function,const FunctionDef * function_def)131 Status ValidateSavedFunctionCompatibleWithFunctionDef(
132     const SavedConcreteFunction& saved_concrete_function,
133     const FunctionDef* function_def) {
134   // tf.functions go through many transformations before becoming FunctionDefs
135   // 1. flatten user-provided inputs:
136   // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2671-L2675
137   // 2. convert user-provided inputs to tensors:
138   // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2687-L2688
139   // 3. filter any non-tensor, non-variable inputs:
140   // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1840-L1841
141   // 4. concatenate any captured inputs:
142   // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1912
143 
144   // Since our API is limited to tf.functions annotated with input signatures,
145   // conditions 2 and 3 are trivially satisfied.
146   // We need to ensure that:
147   // flatten(input_signature).size() + captures.size() = fdef.signature().size()
148   // A concrete function's serialized "canonicalized_input_signature" comes
149   // from encoding its "structured_input_signature" field:
150   // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/saved_model/function_serialization.py#L70-L71
151   // The "structured_input_signature" is guaranteed to be a tuple of the python
152   // args, kwargs that correspond to the tf.function:
153   // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979
154 
155   const std::string& name = function_def->signature().name();
156 
157   const StructuredValue& input_signature =
158       saved_concrete_function.canonicalized_input_signature();
159   std::vector<const TensorSpecProto*> input_specs;
160   TF_RETURN_IF_ERROR(FlattenSignature(input_signature, &input_specs));
161   if (input_specs.size() + saved_concrete_function.bound_inputs_size() !=
162       function_def->signature().input_arg_size()) {
163     return errors::FailedPrecondition(
164         "FunctionDef ", name, " has ",
165         function_def->signature().input_arg_size(),
166         " inputs, but the SavedConcreteFunction has ", input_specs.size(),
167         " flattened user inputs and ",
168         saved_concrete_function.bound_inputs_size(), " captured inputs.");
169   }
170 
171   const StructuredValue& output_signature =
172       saved_concrete_function.output_signature();
173   std::vector<const TensorSpecProto*> output_specs;
174   TF_RETURN_IF_ERROR(FlattenSignature(output_signature, &output_specs));
175   if (output_specs.size() != function_def->signature().output_arg_size()) {
176     return errors::FailedPrecondition(
177         "FunctionDef ", name, " has ",
178         function_def->signature().output_arg_size(),
179         " outputs, but the SavedConcreteFunction has ", output_specs.size(),
180         " flattened outputs.");
181   }
182 
183   return Status();
184 }
185 
ValidateSingleConcreteFunction(const SavedFunction & saved_function)186 Status ValidateSingleConcreteFunction(const SavedFunction& saved_function) {
187   // We only allow loading functions that have an annotated input signature,
188   // which means there is 1:1 correspondence between tf.function
189   // <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is
190   // the same restriction that MLIR has:
191   // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707
192   if (saved_function.concrete_functions_size() != 1) {
193     return errors::FailedPrecondition(
194         "Only tf.functions annotated with an input signature are supported "
195         "by SavedModelAPI. This means that there should only be a single "
196         "ConcreteFunction per tf.function");
197   }
198   return Status();
199 }
200 
201 }  // namespace
202 
LoadSavedAsset(ImmediateExecutionContext * ctx,const SavedAsset & asset,const std::string & saved_model_dir,absl::Span<const AssetFileDef> assets,std::unique_ptr<Asset> * output)203 Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
204                       const std::string& saved_model_dir,
205                       absl::Span<const AssetFileDef> assets,
206                       std::unique_ptr<Asset>* output) {
207   int asset_index = asset.asset_file_def_index();
208   if (asset_index >= assets.size()) {
209     return errors::FailedPrecondition(
210         "SavedAsset contained asset index ", asset_index,
211         " but AssetFileDef only contains ", assets.size(), " # of assets");
212   }
213   const std::string& asset_filename = assets[asset_index].filename();
214   return Asset::Create(ctx, saved_model_dir, asset_filename, output);
215 }
216 
TensorProtoToConstant(ImmediateExecutionContext * ctx,const TensorProto & proto,std::unique_ptr<Constant> * output)217 Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
218                              const TensorProto& proto,
219                              std::unique_ptr<Constant>* output) {
220   tensorflow::Tensor tensor;
221   bool parse_result = tensor.FromProto(proto);
222   if (!parse_result) {
223     return errors::Internal("Failed to parse tensor from tensorproto");
224   }
225 
226   TensorInterface tensor_interface(std::move(tensor));
227   return Constant::Create(ctx, &tensor_interface, output);
228 }
229 
230 // This follows the python variable restoration logic:
231 // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
LoadSavedVariable(ImmediateExecutionContext * ctx,const SavedVariable & variable,std::unique_ptr<Variable> * output)232 Status LoadSavedVariable(ImmediateExecutionContext* ctx,
233                          const SavedVariable& variable,
234                          std::unique_ptr<Variable>* output) {
235   const std::string& name = variable.name();
236   tensorflow::TensorShape shape(variable.shape());
237   tensorflow::DataType dtype = variable.dtype();
238   std::vector<std::string> component_devices;
239 
240   for (const auto& component :
241        variable.experimental_distributed_variable_components()) {
242     component_devices.push_back(component.device());
243   }
244 
245   TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
246       ctx, dtype, shape, name,
247       variable.device().empty() ? nullptr : variable.device().c_str(),
248       component_devices, output));
249   return Status();
250 }
251 
LoadTFConcreteFunction(const SavedConcreteFunction & saved_concrete_function,const FunctionDef * function_def,const std::unordered_map<int,std::unique_ptr<TensorHandleConvertible>> & captured_objects,ImmediateExecutionContext * ctx,std::unique_ptr<TFConcreteFunction> * out)252 Status LoadTFConcreteFunction(
253     const SavedConcreteFunction& saved_concrete_function,
254     const FunctionDef* function_def,
255     const std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>&
256         captured_objects,
257     ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out) {
258   TF_RETURN_IF_ERROR(ValidateSavedFunctionCompatibleWithFunctionDef(
259       saved_concrete_function, function_def));
260 
261   // Copy over captures
262   std::vector<ImmediateExecutionTensorHandle*> captures;
263   captures.reserve(saved_concrete_function.bound_inputs_size());
264   for (int bound_input : saved_concrete_function.bound_inputs()) {
265     auto iter = captured_objects.find(bound_input);
266     if (iter == captured_objects.end()) {
267       return errors::FailedPrecondition("Failed to find bound_input ",
268                                         bound_input,
269                                         " for SavedConcreteFunction");
270     }
271     captures.push_back(iter->second->handle());
272   }
273 
274   return TFConcreteFunction::Create(function_def, std::move(captures), {}, ctx,
275                                     out);
276 }
277 
FlattenSignature(const StructuredValue & signature,std::vector<const TensorSpecProto * > * flattened_specs)278 Status FlattenSignature(const StructuredValue& signature,
279                         std::vector<const TensorSpecProto*>* flattened_specs) {
280   // This follows the logic from
281   // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775
282   switch (signature.kind_case()) {
283     case StructuredValue::kDictValue: {
284       // Dictionaries must be sorted in order of keys
285       const DictValue& dict = signature.dict_value();
286       std::vector<const StructuredValueDictEntry*> entries;
287       entries.reserve(dict.fields_size());
288       for (const auto& field : dict.fields()) {
289         entries.push_back(&field);
290       }
291 
292       std::sort(entries.begin(), entries.end(),
293                 [](const StructuredValueDictEntry* x,
294                    const StructuredValueDictEntry* y) {
295                   return x->first < y->first;
296                 });
297 
298       for (const auto& entry : entries) {
299         TF_RETURN_IF_ERROR(FlattenSignature(entry->second, flattened_specs));
300       }
301       return Status();
302     }
303     case StructuredValue::kTupleValue: {
304       const TupleValue& tuple = signature.tuple_value();
305       for (const StructuredValue& value : tuple.values()) {
306         TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs));
307       }
308       return Status();
309     }
310     case StructuredValue::kListValue: {
311       const ListValue& list = signature.list_value();
312       for (const StructuredValue& value : list.values()) {
313         TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs));
314       }
315       return Status();
316     }
317     case StructuredValue::kTensorSpecValue: {
318       flattened_specs->push_back(&signature.tensor_spec_value());
319       return Status();
320     }
321     case StructuredValue::kNoneValue: {
322       // Base case: do nothing.
323       // This arises, for example, as the top-level object of an output
324       // signature when there are no return values.
325       return Status();
326     }
327     default: {
328       return errors::Internal("Unhandled structured value kind ",
329                               signature.kind_case());
330     }
331   }
332 }
333 
FindNodeAtPath(StringPiece path,const SavedObjectGraph & object_graph)334 absl::optional<int> FindNodeAtPath(StringPiece path,
335                                    const SavedObjectGraph& object_graph) {
336   const auto& nodes = object_graph.nodes();
337   if (nodes.empty()) {
338     return absl::nullopt;
339   }
340 
341   // Starting from the root, iterate through the saved object graph, matching
342   // object names as we go.
343   int node_id = 0;
344   const SavedObject* current_node = &nodes.Get(node_id);
345 
346   for (absl::string_view object_name : absl::StrSplit(path, '.')) {
347     auto child_node_iter = std::find_if(
348         current_node->children().begin(), current_node->children().end(),
349         [object_name](
350             const TrackableObjectGraph::TrackableObject::ObjectReference& obj) {
351           return object_name == obj.local_name();
352         });
353     if (child_node_iter == current_node->children().end()) {
354       return absl::nullopt;
355     }
356 
357     node_id = child_node_iter->node_id();
358     current_node = &nodes.Get(node_id);
359   }
360 
361   return node_id;
362 }
363 
NodeToAttrMap(const tensorflow::GraphDef & graphdef)364 gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> NodeToAttrMap(
365     const tensorflow::GraphDef& graphdef) {
366   gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> result;
367   for (const tensorflow::NodeDef& node : graphdef.node()) {
368     result[node.name()] = &node.attr();
369   }
370   return result;
371 }
372 
373 gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
FunctionNameToFunctionDefMap(const FunctionDefLibrary & library)374 FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) {
375   gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
376       result;
377   for (const FunctionDef& function_def : library.function()) {
378     result[function_def.signature().name()] = &function_def;
379   }
380   return result;
381 }
382 
PartiallyReviveSavedModelObjects(const MetaGraphDef & metagraph,ImmediateExecutionContext * context,const std::string & directory,PartiallyRevivedObjects * objects)383 Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
384                                         ImmediateExecutionContext* context,
385                                         const std::string& directory,
386                                         PartiallyRevivedObjects* objects) {
387   // This is needed to restore "Constant" nodes by looking up their
388   // "Value" attribute.
389   NodeAttrMap node_attr_map = NodeToAttrMap(metagraph.graph_def());
390 
391   // These are needed for creating "Assets", by looking up their filenames.
392   std::vector<AssetFileDef> assets;
393   TF_RETURN_IF_ERROR(GetAssetFileDefs(metagraph, &assets));
394 
395   // Signatures are needed for determining whether a function is a
396   // SignatureDefFunction or not.
397   gtl::FlatMap<std::string, int> signatures_map;
398   TF_RETURN_IF_ERROR(
399       GetSignaturesMap(metagraph.object_graph_def(), &signatures_map));
400 
401   gtl::FlatMap<int, std::string> reversed_signatures_map;
402   reversed_signatures_map.reserve(signatures_map.size());
403   for (const auto& signature_key_and_node : signatures_map) {
404     reversed_signatures_map.emplace(signature_key_and_node.second,
405                                     signature_key_and_node.first);
406   }
407 
408   // FunctionDefs are needed to help construct
409   // TFConcreteFunction/SignatureDefFunctions
410   const FunctionDefMap function_def_map =
411       internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library());
412 
413   // Iterate through all the saved objects, restoring objects (if we can) as we
414   // go. For objects that dependencies on other objects (resources/functions),
415   // we partially initialize "builders" that correspond to their currently known
416   // state, and gradually fill them out in subsequent passes.
417   for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
418     const SavedObject& node = metagraph.object_graph_def().nodes(i);
419     if (node.kind_case() == SavedObject::kVariable) {
420       std::unique_ptr<Variable> variable;
421       TF_RETURN_IF_ERROR(
422           LoadSavedVariable(context, node.variable(), &variable));
423       objects->variables[i] = std::move(variable);
424     } else if (node.kind_case() == SavedObject::kConstant) {
425       std::unique_ptr<Constant> constant;
426       TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(),
427                                                    node_attr_map, &constant));
428       objects->constants[i] = std::move(constant);
429     } else if (node.kind_case() == SavedObject::kAsset) {
430       std::unique_ptr<Asset> asset;
431       TF_RETURN_IF_ERROR(
432           LoadSavedAsset(context, node.asset(), directory, assets, &asset));
433       objects->assets[i] = std::move(asset);
434     } else if (node.kind_case() == SavedObject::kResource) {
435       RestoredResourceRevivalState resource_revival_state;
436       // We'll set the resource's functions in a subsequent pass, once we get
437       // all functions in a partially revived state.
438       resource_revival_state.device = node.resource().device();
439       objects->restored_resources[i] = std::move(resource_revival_state);
440     } else if (node.kind_case() == SavedObject::kFunction) {
441       // Get the SavedFunction node and validate it has a single concrete func.
442       const SavedFunction& saved_function = node.function();
443       TF_RETURN_IF_ERROR(ValidateSingleConcreteFunction(saved_function));
444 
445       // Retrieve related function information.
446       const std::string& function_name = saved_function.concrete_functions(0);
447       const FunctionDef* function_def = function_def_map.at(function_name);
448       const SavedConcreteFunction& saved_concrete_func =
449           metagraph.object_graph_def().concrete_functions().at(function_name);
450       const FunctionSpec& function_spec = saved_function.function_spec();
451 
452       // Construct either a SignatureDefFunctionBuilder or a
453       // ConcreteFunctionBuilder, depending on whether this node was a child
454       // of the "signatures" attribute from root object.
455       auto reverse_signature_iter = reversed_signatures_map.find(i);
456       if (reverse_signature_iter != reversed_signatures_map.end()) {
457         TFSignatureDefFunctionRevivalState func_revival_state;
458         func_revival_state.node_id = i;
459         func_revival_state.fdef = function_def;
460         func_revival_state.saved_concrete_func = &saved_concrete_func;
461         func_revival_state.signature_key = reverse_signature_iter->second;
462         objects->signature_def_functions[i] = std::move(func_revival_state);
463       } else {
464         TFConcreteFunctionRevivalState func_revival_state;
465         func_revival_state.node_id = i;
466         func_revival_state.fdef = function_def;
467         func_revival_state.saved_concrete_func = &saved_concrete_func;
468         func_revival_state.function_spec = &function_spec;
469         objects->concrete_functions[i] = std::move(func_revival_state);
470       }
471     } else if (node.kind_case() == SavedObject::kBareConcreteFunction) {
472       const SavedBareConcreteFunction& bare_cf = node.bare_concrete_function();
473 
474       // Retrieve related function information.
475       const std::string& function_name = bare_cf.concrete_function_name();
476       const FunctionDef* function_def = function_def_map.at(function_name);
477       const SavedConcreteFunction& saved_concrete_func =
478           metagraph.object_graph_def().concrete_functions().at(function_name);
479 
480       // Check whether this is a SignatureDefFunction, or not.
481       auto reverse_signature_iter = reversed_signatures_map.find(i);
482       if (reverse_signature_iter != reversed_signatures_map.end()) {
483         TFSignatureDefFunctionRevivalState func_revival_state;
484         func_revival_state.node_id = i;
485         func_revival_state.fdef = function_def;
486         func_revival_state.saved_concrete_func = &saved_concrete_func;
487         func_revival_state.signature_key = reverse_signature_iter->second;
488         objects->signature_def_functions[i] = std::move(func_revival_state);
489       } else {
490         TFConcreteFunctionRevivalState func_revival_state;
491         func_revival_state.node_id = i;
492         func_revival_state.fdef = function_def;
493         func_revival_state.saved_concrete_func = &saved_concrete_func;
494         objects->concrete_functions[i] = std::move(func_revival_state);
495       }
496     }
497   }
498 
499   // Now that we've partially restored all functions, we can have resources
500   // point to them
501   for (auto& node_and_resource_revival_state : objects->restored_resources) {
502     int node_id = node_and_resource_revival_state.first;
503     const SavedObjectGraph& obj_graph = metagraph.object_graph_def();
504     const SavedObject& node = obj_graph.nodes(node_id);
505     RestoredResourceRevivalState& resource =
506         node_and_resource_revival_state.second;
507     for (const TrackableObjectGraph::TrackableObject::ObjectReference& child :
508          node.children()) {
509       int child_node_id = child.node_id();
510       // Note(bmzhao): The expected functions saved by a resource object are:
511       // "_create_resource", "_initialize", and "_destroy_resource".
512       // https://github.com/tensorflow/tensorflow/blob/ad66f588c1666ade8051feb42811fa27b285271c/tensorflow/python/training/tracking/tracking.py#L277-L281
513       if (child.local_name() == "_create_resource" &&
514           obj_graph.nodes(child.node_id()).kind_case() ==
515               SavedObject::kFunction) {
516         resource.create_resource = &objects->concrete_functions[child_node_id];
517       } else if (child.local_name() == "_initialize" &&
518                  obj_graph.nodes(child.node_id()).kind_case() ==
519                      SavedObject::kFunction) {
520         resource.initialize = &objects->concrete_functions[child_node_id];
521       } else if (child.local_name() == "_destroy_resource" &&
522                  obj_graph.nodes(child.node_id()).kind_case() ==
523                      SavedObject::kFunction) {
524         resource.destroy_resource = &objects->concrete_functions[child_node_id];
525       }
526     }
527   }
528 
529   objects->signatures_map = std::move(signatures_map);
530 
531   return Status();
532 }
533 
534 }  // namespace internal
535 }  // namespace tensorflow
536