1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <unordered_map>
21 #include <unordered_set>
22
23 #include "absl/strings/str_split.h"
24 #include "absl/types/optional.h"
25 #include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
26 #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
27 #include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h"
28 #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h"
29 #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h"
30 #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
31 #include "tensorflow/c/tf_tensor_internal.h"
32 #include "tensorflow/cc/saved_model/loader_util.h"
33 #include "tensorflow/core/framework/node_def_util.h"
34 #include "tensorflow/core/framework/tensor.pb.h"
35 #include "tensorflow/core/lib/gtl/flatmap.h"
36 #include "tensorflow/core/lib/hash/hash.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/protobuf.h"
39 #include "tensorflow/core/platform/stringpiece.h"
40 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
41 #include "tensorflow/core/protobuf/struct.pb.h"
42 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
43
44 namespace tensorflow {
45 namespace internal {
46 namespace {
47
48 using StructuredValueDictEntry =
49 protobuf::MapPair<std::string, StructuredValue>;
50
51 // Maps from a Nodedef's name to its corresponding AttrValues, for a given
52 // Graphdef
53 using NodeAttrMap =
54 gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher>;
55
56 // Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
57 using FunctionDefMap = gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*,
58 StringPieceHasher>;
59
60 // Looks up a SavedConstant's associated tensorproto from the NodeAttrMap and
61 // returns a tensorflow::Constant.
ConstantFromSavedConstant(ImmediateExecutionContext * ctx,const tensorflow::SavedConstant & saved_constant,const NodeAttrMap & node_attr_map,std::unique_ptr<Constant> * output)62 Status ConstantFromSavedConstant(
63 ImmediateExecutionContext* ctx,
64 const tensorflow::SavedConstant& saved_constant,
65 const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) {
66 const std::string& const_op_name = saved_constant.operation();
67 const auto& node_name_and_attrs = node_attr_map.find(const_op_name);
68 if (node_name_and_attrs == node_attr_map.end()) {
69 return errors::FailedPrecondition(
70 "Unable to find Const operation with name'", const_op_name,
71 "' in SavedModel graphdef");
72 }
73 const AttrValueMap* attrs = node_name_and_attrs->second;
74 const auto& attr_name_and_value = attrs->find("value");
75 if (attr_name_and_value == attrs->end()) {
76 return errors::FailedPrecondition("Unable to find Const operation '",
77 const_op_name, "'s value attribute");
78 }
79 const TensorProto& tensor_proto = attr_name_and_value->second.tensor();
80 return internal::TensorProtoToConstant(ctx, tensor_proto, output);
81 }
82
83 // Finds the "signatures" object in the object graph, and fills a mapping of
84 // each signature's name to the corresponding function's node in the object
85 // graph.
GetSignaturesMap(const SavedObjectGraph & saved_objects,gtl::FlatMap<std::string,int> * signatures_map)86 Status GetSignaturesMap(const SavedObjectGraph& saved_objects,
87 gtl::FlatMap<std::string, int>* signatures_map) {
88 if (saved_objects.nodes().empty()) {
89 return errors::FailedPrecondition("Saved Object Graph was empty.");
90 }
91 const SavedObject& root = saved_objects.nodes(0);
92 const SavedObject* signatures = nullptr;
93 for (const auto& child : root.children()) {
94 if (child.local_name() == "signatures") {
95 if (child.node_id() >= saved_objects.nodes().size()) {
96 return errors::FailedPrecondition(
97 "Signature object had child node id ", child.node_id(),
98 " which exceeds the size of the set of nodes");
99 }
100 signatures = &saved_objects.nodes(child.node_id());
101 }
102 }
103
104 // Some basic sanity checks that this object is actually our "signatures" map
105 if (signatures == nullptr) {
106 // This is where the "signatures" attribute is always set:
107 // https://github.com/tensorflow/tensorflow/blob/a2c542a0d83227568f9214a2af9a38ae3625976f/tensorflow/python/saved_model/save.py#L1106-L1109
108 return errors::FailedPrecondition(
109 "SavedObjectGraph's root object must have a child 'signatures' object");
110 }
111 if (signatures->kind_case() != SavedObject::kUserObject) {
112 return errors::FailedPrecondition(
113 "Signatures must be a SavedObject of type UserObject.");
114 }
115 if (signatures->user_object().identifier() != "signature_map") {
116 // This is where the string comes from:
117 // https://github.com/tensorflow/tensorflow/blob/c59af2913aaec235d883f50428efef1086f4c0e6/tensorflow/python/saved_model/signature_serialization.py#L220
118 return errors::FailedPrecondition(
119 "Signatures SavedObject must have identifier 'signature_map'.");
120 }
121
122 for (const auto& child : signatures->children()) {
123 (*signatures_map)[child.local_name()] = child.node_id();
124 }
125 return Status();
126 }
127
128 // Perform some basic sanity checks on SavedConcreteFunction's input and
129 // output signatures with respect to the corresponding FunctionDef's input
130 // and output args.
ValidateSavedFunctionCompatibleWithFunctionDef(const SavedConcreteFunction & saved_concrete_function,const FunctionDef * function_def)131 Status ValidateSavedFunctionCompatibleWithFunctionDef(
132 const SavedConcreteFunction& saved_concrete_function,
133 const FunctionDef* function_def) {
134 // tf.functions go through many transformations before becoming FunctionDefs
135 // 1. flatten user-provided inputs:
136 // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2671-L2675
137 // 2. convert user-provided inputs to tensors:
138 // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2687-L2688
139 // 3. filter any non-tensor, non-variable inputs:
140 // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1840-L1841
141 // 4. concatenate any captured inputs:
142 // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1912
143
144 // Since our API is limited to tf.functions annotated with input signatures,
145 // conditions 2 and 3 are trivially satisfied.
146 // We need to ensure that:
147 // flatten(input_signature).size() + captures.size() = fdef.signature().size()
148 // A concrete function's serialized "canonicalized_input_signature" comes
149 // from encoding its "structured_input_signature" field:
150 // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/saved_model/function_serialization.py#L70-L71
151 // The "structured_input_signature" is guaranteed to be a tuple of the python
152 // args, kwargs that correspond to the tf.function:
153 // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979
154
155 const std::string& name = function_def->signature().name();
156
157 const StructuredValue& input_signature =
158 saved_concrete_function.canonicalized_input_signature();
159 std::vector<const TensorSpecProto*> input_specs;
160 TF_RETURN_IF_ERROR(FlattenSignature(input_signature, &input_specs));
161 if (input_specs.size() + saved_concrete_function.bound_inputs_size() !=
162 function_def->signature().input_arg_size()) {
163 return errors::FailedPrecondition(
164 "FunctionDef ", name, " has ",
165 function_def->signature().input_arg_size(),
166 " inputs, but the SavedConcreteFunction has ", input_specs.size(),
167 " flattened user inputs and ",
168 saved_concrete_function.bound_inputs_size(), " captured inputs.");
169 }
170
171 const StructuredValue& output_signature =
172 saved_concrete_function.output_signature();
173 std::vector<const TensorSpecProto*> output_specs;
174 TF_RETURN_IF_ERROR(FlattenSignature(output_signature, &output_specs));
175 if (output_specs.size() != function_def->signature().output_arg_size()) {
176 return errors::FailedPrecondition(
177 "FunctionDef ", name, " has ",
178 function_def->signature().output_arg_size(),
179 " outputs, but the SavedConcreteFunction has ", output_specs.size(),
180 " flattened outputs.");
181 }
182
183 return Status();
184 }
185
ValidateSingleConcreteFunction(const SavedFunction & saved_function)186 Status ValidateSingleConcreteFunction(const SavedFunction& saved_function) {
187 // We only allow loading functions that have an annotated input signature,
188 // which means there is 1:1 correspondence between tf.function
189 // <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is
190 // the same restriction that MLIR has:
191 // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707
192 if (saved_function.concrete_functions_size() != 1) {
193 return errors::FailedPrecondition(
194 "Only tf.functions annotated with an input signature are supported "
195 "by SavedModelAPI. This means that there should only be a single "
196 "ConcreteFunction per tf.function");
197 }
198 return Status();
199 }
200
201 } // namespace
202
LoadSavedAsset(ImmediateExecutionContext * ctx,const SavedAsset & asset,const std::string & saved_model_dir,absl::Span<const AssetFileDef> assets,std::unique_ptr<Asset> * output)203 Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
204 const std::string& saved_model_dir,
205 absl::Span<const AssetFileDef> assets,
206 std::unique_ptr<Asset>* output) {
207 int asset_index = asset.asset_file_def_index();
208 if (asset_index >= assets.size()) {
209 return errors::FailedPrecondition(
210 "SavedAsset contained asset index ", asset_index,
211 " but AssetFileDef only contains ", assets.size(), " # of assets");
212 }
213 const std::string& asset_filename = assets[asset_index].filename();
214 return Asset::Create(ctx, saved_model_dir, asset_filename, output);
215 }
216
TensorProtoToConstant(ImmediateExecutionContext * ctx,const TensorProto & proto,std::unique_ptr<Constant> * output)217 Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
218 const TensorProto& proto,
219 std::unique_ptr<Constant>* output) {
220 tensorflow::Tensor tensor;
221 bool parse_result = tensor.FromProto(proto);
222 if (!parse_result) {
223 return errors::Internal("Failed to parse tensor from tensorproto");
224 }
225
226 TensorInterface tensor_interface(std::move(tensor));
227 return Constant::Create(ctx, &tensor_interface, output);
228 }
229
230 // This follows the python variable restoration logic:
231 // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
LoadSavedVariable(ImmediateExecutionContext * ctx,const SavedVariable & variable,std::unique_ptr<Variable> * output)232 Status LoadSavedVariable(ImmediateExecutionContext* ctx,
233 const SavedVariable& variable,
234 std::unique_ptr<Variable>* output) {
235 const std::string& name = variable.name();
236 tensorflow::TensorShape shape(variable.shape());
237 tensorflow::DataType dtype = variable.dtype();
238 std::vector<std::string> component_devices;
239
240 for (const auto& component :
241 variable.experimental_distributed_variable_components()) {
242 component_devices.push_back(component.device());
243 }
244
245 TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
246 ctx, dtype, shape, name,
247 variable.device().empty() ? nullptr : variable.device().c_str(),
248 component_devices, output));
249 return Status();
250 }
251
LoadTFConcreteFunction(const SavedConcreteFunction & saved_concrete_function,const FunctionDef * function_def,const std::unordered_map<int,std::unique_ptr<TensorHandleConvertible>> & captured_objects,ImmediateExecutionContext * ctx,std::unique_ptr<TFConcreteFunction> * out)252 Status LoadTFConcreteFunction(
253 const SavedConcreteFunction& saved_concrete_function,
254 const FunctionDef* function_def,
255 const std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>&
256 captured_objects,
257 ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out) {
258 TF_RETURN_IF_ERROR(ValidateSavedFunctionCompatibleWithFunctionDef(
259 saved_concrete_function, function_def));
260
261 // Copy over captures
262 std::vector<ImmediateExecutionTensorHandle*> captures;
263 captures.reserve(saved_concrete_function.bound_inputs_size());
264 for (int bound_input : saved_concrete_function.bound_inputs()) {
265 auto iter = captured_objects.find(bound_input);
266 if (iter == captured_objects.end()) {
267 return errors::FailedPrecondition("Failed to find bound_input ",
268 bound_input,
269 " for SavedConcreteFunction");
270 }
271 captures.push_back(iter->second->handle());
272 }
273
274 return TFConcreteFunction::Create(function_def, std::move(captures), {}, ctx,
275 out);
276 }
277
FlattenSignature(const StructuredValue & signature,std::vector<const TensorSpecProto * > * flattened_specs)278 Status FlattenSignature(const StructuredValue& signature,
279 std::vector<const TensorSpecProto*>* flattened_specs) {
280 // This follows the logic from
281 // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775
282 switch (signature.kind_case()) {
283 case StructuredValue::kDictValue: {
284 // Dictionaries must be sorted in order of keys
285 const DictValue& dict = signature.dict_value();
286 std::vector<const StructuredValueDictEntry*> entries;
287 entries.reserve(dict.fields_size());
288 for (const auto& field : dict.fields()) {
289 entries.push_back(&field);
290 }
291
292 std::sort(entries.begin(), entries.end(),
293 [](const StructuredValueDictEntry* x,
294 const StructuredValueDictEntry* y) {
295 return x->first < y->first;
296 });
297
298 for (const auto& entry : entries) {
299 TF_RETURN_IF_ERROR(FlattenSignature(entry->second, flattened_specs));
300 }
301 return Status();
302 }
303 case StructuredValue::kTupleValue: {
304 const TupleValue& tuple = signature.tuple_value();
305 for (const StructuredValue& value : tuple.values()) {
306 TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs));
307 }
308 return Status();
309 }
310 case StructuredValue::kListValue: {
311 const ListValue& list = signature.list_value();
312 for (const StructuredValue& value : list.values()) {
313 TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs));
314 }
315 return Status();
316 }
317 case StructuredValue::kTensorSpecValue: {
318 flattened_specs->push_back(&signature.tensor_spec_value());
319 return Status();
320 }
321 case StructuredValue::kNoneValue: {
322 // Base case: do nothing.
323 // This arises, for example, as the top-level object of an output
324 // signature when there are no return values.
325 return Status();
326 }
327 default: {
328 return errors::Internal("Unhandled structured value kind ",
329 signature.kind_case());
330 }
331 }
332 }
333
FindNodeAtPath(StringPiece path,const SavedObjectGraph & object_graph)334 absl::optional<int> FindNodeAtPath(StringPiece path,
335 const SavedObjectGraph& object_graph) {
336 const auto& nodes = object_graph.nodes();
337 if (nodes.empty()) {
338 return absl::nullopt;
339 }
340
341 // Starting from the root, iterate through the saved object graph, matching
342 // object names as we go.
343 int node_id = 0;
344 const SavedObject* current_node = &nodes.Get(node_id);
345
346 for (absl::string_view object_name : absl::StrSplit(path, '.')) {
347 auto child_node_iter = std::find_if(
348 current_node->children().begin(), current_node->children().end(),
349 [object_name](
350 const TrackableObjectGraph::TrackableObject::ObjectReference& obj) {
351 return object_name == obj.local_name();
352 });
353 if (child_node_iter == current_node->children().end()) {
354 return absl::nullopt;
355 }
356
357 node_id = child_node_iter->node_id();
358 current_node = &nodes.Get(node_id);
359 }
360
361 return node_id;
362 }
363
NodeToAttrMap(const tensorflow::GraphDef & graphdef)364 gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> NodeToAttrMap(
365 const tensorflow::GraphDef& graphdef) {
366 gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> result;
367 for (const tensorflow::NodeDef& node : graphdef.node()) {
368 result[node.name()] = &node.attr();
369 }
370 return result;
371 }
372
373 gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
FunctionNameToFunctionDefMap(const FunctionDefLibrary & library)374 FunctionNameToFunctionDefMap(const FunctionDefLibrary& library) {
375 gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
376 result;
377 for (const FunctionDef& function_def : library.function()) {
378 result[function_def.signature().name()] = &function_def;
379 }
380 return result;
381 }
382
PartiallyReviveSavedModelObjects(const MetaGraphDef & metagraph,ImmediateExecutionContext * context,const std::string & directory,PartiallyRevivedObjects * objects)383 Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
384 ImmediateExecutionContext* context,
385 const std::string& directory,
386 PartiallyRevivedObjects* objects) {
387 // This is needed to restore "Constant" nodes by looking up their
388 // "Value" attribute.
389 NodeAttrMap node_attr_map = NodeToAttrMap(metagraph.graph_def());
390
391 // These are needed for creating "Assets", by looking up their filenames.
392 std::vector<AssetFileDef> assets;
393 TF_RETURN_IF_ERROR(GetAssetFileDefs(metagraph, &assets));
394
395 // Signatures are needed for determining whether a function is a
396 // SignatureDefFunction or not.
397 gtl::FlatMap<std::string, int> signatures_map;
398 TF_RETURN_IF_ERROR(
399 GetSignaturesMap(metagraph.object_graph_def(), &signatures_map));
400
401 gtl::FlatMap<int, std::string> reversed_signatures_map;
402 reversed_signatures_map.reserve(signatures_map.size());
403 for (const auto& signature_key_and_node : signatures_map) {
404 reversed_signatures_map.emplace(signature_key_and_node.second,
405 signature_key_and_node.first);
406 }
407
408 // FunctionDefs are needed to help construct
409 // TFConcreteFunction/SignatureDefFunctions
410 const FunctionDefMap function_def_map =
411 internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library());
412
413 // Iterate through all the saved objects, restoring objects (if we can) as we
414 // go. For objects that dependencies on other objects (resources/functions),
415 // we partially initialize "builders" that correspond to their currently known
416 // state, and gradually fill them out in subsequent passes.
417 for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
418 const SavedObject& node = metagraph.object_graph_def().nodes(i);
419 if (node.kind_case() == SavedObject::kVariable) {
420 std::unique_ptr<Variable> variable;
421 TF_RETURN_IF_ERROR(
422 LoadSavedVariable(context, node.variable(), &variable));
423 objects->variables[i] = std::move(variable);
424 } else if (node.kind_case() == SavedObject::kConstant) {
425 std::unique_ptr<Constant> constant;
426 TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(),
427 node_attr_map, &constant));
428 objects->constants[i] = std::move(constant);
429 } else if (node.kind_case() == SavedObject::kAsset) {
430 std::unique_ptr<Asset> asset;
431 TF_RETURN_IF_ERROR(
432 LoadSavedAsset(context, node.asset(), directory, assets, &asset));
433 objects->assets[i] = std::move(asset);
434 } else if (node.kind_case() == SavedObject::kResource) {
435 RestoredResourceRevivalState resource_revival_state;
436 // We'll set the resource's functions in a subsequent pass, once we get
437 // all functions in a partially revived state.
438 resource_revival_state.device = node.resource().device();
439 objects->restored_resources[i] = std::move(resource_revival_state);
440 } else if (node.kind_case() == SavedObject::kFunction) {
441 // Get the SavedFunction node and validate it has a single concrete func.
442 const SavedFunction& saved_function = node.function();
443 TF_RETURN_IF_ERROR(ValidateSingleConcreteFunction(saved_function));
444
445 // Retrieve related function information.
446 const std::string& function_name = saved_function.concrete_functions(0);
447 const FunctionDef* function_def = function_def_map.at(function_name);
448 const SavedConcreteFunction& saved_concrete_func =
449 metagraph.object_graph_def().concrete_functions().at(function_name);
450 const FunctionSpec& function_spec = saved_function.function_spec();
451
452 // Construct either a SignatureDefFunctionBuilder or a
453 // ConcreteFunctionBuilder, depending on whether this node was a child
454 // of the "signatures" attribute from root object.
455 auto reverse_signature_iter = reversed_signatures_map.find(i);
456 if (reverse_signature_iter != reversed_signatures_map.end()) {
457 TFSignatureDefFunctionRevivalState func_revival_state;
458 func_revival_state.node_id = i;
459 func_revival_state.fdef = function_def;
460 func_revival_state.saved_concrete_func = &saved_concrete_func;
461 func_revival_state.signature_key = reverse_signature_iter->second;
462 objects->signature_def_functions[i] = std::move(func_revival_state);
463 } else {
464 TFConcreteFunctionRevivalState func_revival_state;
465 func_revival_state.node_id = i;
466 func_revival_state.fdef = function_def;
467 func_revival_state.saved_concrete_func = &saved_concrete_func;
468 func_revival_state.function_spec = &function_spec;
469 objects->concrete_functions[i] = std::move(func_revival_state);
470 }
471 } else if (node.kind_case() == SavedObject::kBareConcreteFunction) {
472 const SavedBareConcreteFunction& bare_cf = node.bare_concrete_function();
473
474 // Retrieve related function information.
475 const std::string& function_name = bare_cf.concrete_function_name();
476 const FunctionDef* function_def = function_def_map.at(function_name);
477 const SavedConcreteFunction& saved_concrete_func =
478 metagraph.object_graph_def().concrete_functions().at(function_name);
479
480 // Check whether this is a SignatureDefFunction, or not.
481 auto reverse_signature_iter = reversed_signatures_map.find(i);
482 if (reverse_signature_iter != reversed_signatures_map.end()) {
483 TFSignatureDefFunctionRevivalState func_revival_state;
484 func_revival_state.node_id = i;
485 func_revival_state.fdef = function_def;
486 func_revival_state.saved_concrete_func = &saved_concrete_func;
487 func_revival_state.signature_key = reverse_signature_iter->second;
488 objects->signature_def_functions[i] = std::move(func_revival_state);
489 } else {
490 TFConcreteFunctionRevivalState func_revival_state;
491 func_revival_state.node_id = i;
492 func_revival_state.fdef = function_def;
493 func_revival_state.saved_concrete_func = &saved_concrete_func;
494 objects->concrete_functions[i] = std::move(func_revival_state);
495 }
496 }
497 }
498
499 // Now that we've partially restored all functions, we can have resources
500 // point to them
501 for (auto& node_and_resource_revival_state : objects->restored_resources) {
502 int node_id = node_and_resource_revival_state.first;
503 const SavedObjectGraph& obj_graph = metagraph.object_graph_def();
504 const SavedObject& node = obj_graph.nodes(node_id);
505 RestoredResourceRevivalState& resource =
506 node_and_resource_revival_state.second;
507 for (const TrackableObjectGraph::TrackableObject::ObjectReference& child :
508 node.children()) {
509 int child_node_id = child.node_id();
510 // Note(bmzhao): The expected functions saved by a resource object are:
511 // "_create_resource", "_initialize", and "_destroy_resource".
512 // https://github.com/tensorflow/tensorflow/blob/ad66f588c1666ade8051feb42811fa27b285271c/tensorflow/python/training/tracking/tracking.py#L277-L281
513 if (child.local_name() == "_create_resource" &&
514 obj_graph.nodes(child.node_id()).kind_case() ==
515 SavedObject::kFunction) {
516 resource.create_resource = &objects->concrete_functions[child_node_id];
517 } else if (child.local_name() == "_initialize" &&
518 obj_graph.nodes(child.node_id()).kind_case() ==
519 SavedObject::kFunction) {
520 resource.initialize = &objects->concrete_functions[child_node_id];
521 } else if (child.local_name() == "_destroy_resource" &&
522 obj_graph.nodes(child.node_id()).kind_case() ==
523 SavedObject::kFunction) {
524 resource.destroy_resource = &objects->concrete_functions[child_node_id];
525 }
526 }
527 }
528
529 objects->signatures_map = std::move(signatures_map);
530
531 return Status();
532 }
533
534 } // namespace internal
535 } // namespace tensorflow
536