1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
17 #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
18 
19 // Some internal utility functions for the SavedModelAPI, factored out into a
20 // separately unit-testable header.
21 
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
42 
namespace tensorflow {
namespace internal {

// Loads a TensorProto into a tensorflow::Constant. This is similar to the
// constant loading logic in python:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L437
//
// `ctx` is the execution context the constant is created with (presumably it
// owns the resulting tensor handle -- confirm in the implementation).
// The new Constant is written to `*output`; the returned Status reports
// whether loading succeeded.
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
                             const TensorProto& proto,
                             std::unique_ptr<Constant>* output);

// Creates a tensorflow::Variable from a SavedVariable. This is similar to the
// logic in:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
// The new Variable is written to `*output`.
// Note that the caller **must assign a value** to the loaded variable.
Status LoadSavedVariable(ImmediateExecutionContext* ctx,
                         const SavedVariable& variable,
                         std::unique_ptr<Variable>* output);

// Creates a tensorflow::Asset from a SavedAsset, writing it to `*output`.
// NOTE(review): presumably `asset` selects one entry of `assets` (the
// MetaGraphDef's AssetFileDefs), and the asset's file path is resolved
// relative to `saved_model_dir` -- confirm against the implementation.
Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset,
                      const std::string& saved_model_dir,
                      absl::Span<const AssetFileDef> assets,
                      std::unique_ptr<Asset>* output);

// Creates a TFConcreteFunction from a SavedConcreteFunction, writing it to
// `*out`.
//  - `function_def`: the FunctionDef backing `saved_concrete_function`
//    (presumably looked up by the concrete function's name -- confirm).
//  - `captured_objects`: objects convertible to tensor handles, keyed by int;
//    presumably SavedObjectGraph node ids of the function's captured
//    inputs -- TODO confirm.
//  - `ctx`: the execution context the function is created in.
Status LoadTFConcreteFunction(
    const SavedConcreteFunction& saved_concrete_function,
    const FunctionDef* function_def,
    const std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>&
        captured_objects,
    ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out);

// Flattens `signature` into a vector of TensorSpecProto pointers that point
// back into `signature`, so `signature` must outlive `flattened_specs`.
// `signature` must also be the input or output signature of a
// SavedConcreteFunction (i.e. "nested structures of tensorspecs").
// NOTE(review): this header does not say whether `flattened_specs` is cleared
// first or appended to -- check the implementation before reusing a vector.
Status FlattenSignature(const StructuredValue& signature,
                        std::vector<const TensorSpecProto*>* flattened_specs);

// Finds the node id in `object_graph` at location `path`. `path` must be a
// dot-delimited string of object names relative to the root object (for
// example "a.b.c"). Returns absl::nullopt if no object is found.
absl::optional<int> FindNodeAtPath(StringPiece path,
                                   const SavedObjectGraph& object_graph);

// Maps each node in `graphdef` to its corresponding attribute map. The keys
// (StringPieces, presumably the node names) and the AttrValueMap pointers
// refer into `graphdef`, so callers must ensure that `graphdef` outlives the
// returned map.
gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> NodeToAttrMap(
    const tensorflow::GraphDef& graphdef);

// Maps the name of each FunctionDef in `library` to its corresponding
// FunctionDef. Both the name keys and the FunctionDef pointers refer into
// `library`, so callers must ensure `library` outlives the returned map.
gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher>
FunctionNameToFunctionDefMap(const FunctionDefLibrary& library);

// Walks through the SavedObjectGraph in `metagraph` and restores all nodes
// (except "UserDefinedObjects") with their corresponding type into `*objects`
// (a PartiallyRevivedObjects).
// NOTE(review): `context` is presumably used to create the revived runtime
// objects, and `directory` is presumably the SavedModel directory used to
// resolve asset paths -- confirm.
Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph,
                                        ImmediateExecutionContext* context,
                                        const std::string& directory,
                                        PartiallyRevivedObjects* objects);

}  // namespace internal
}  // namespace tensorflow
107 
108 #endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
109