1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_ 17 #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_ 18 19 // Some internal utility functions for the SavedModelAPI, factored out into a 20 // separately unit-testable header. 21 22 #include <memory> 23 #include <unordered_map> 24 25 #include "absl/types/optional.h" 26 #include "absl/types/span.h" 27 #include "tensorflow/c/eager/immediate_execution_context.h" 28 #include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" 29 #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" 30 #include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" 31 #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" 32 #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" 33 #include "tensorflow/core/framework/node_def_util.h" 34 #include "tensorflow/core/framework/tensor.pb.h" 35 #include "tensorflow/core/lib/gtl/flatmap.h" 36 #include "tensorflow/core/lib/hash/hash.h" 37 #include "tensorflow/core/platform/status.h" 38 #include "tensorflow/core/platform/stringpiece.h" 39 #include "tensorflow/core/protobuf/meta_graph.pb.h" 40 #include "tensorflow/core/protobuf/saved_object_graph.pb.h" 41 #include "tensorflow/core/protobuf/struct.pb.h" 42 43 namespace tensorflow { 44 namespace internal { 45 46 // Load a TensorProto into a tensorflow::Constant. This is similar to the 47 // constant loading logic in python: 48 // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L437 49 Status TensorProtoToConstant(ImmediateExecutionContext* ctx, 50 const TensorProto& proto, 51 std::unique_ptr<Constant>* output); 52 53 // Creates a tensorflow::Variable from a SavedVariable. This is similar to the 54 // logic in: 55 // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407 56 // Note that the caller **must assign a value** to the loaded variable. 57 Status LoadSavedVariable(ImmediateExecutionContext* ctx, 58 const SavedVariable& variable, 59 std::unique_ptr<Variable>* output); 60 61 Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset, 62 const std::string& saved_model_dir, 63 absl::Span<const AssetFileDef> assets, 64 std::unique_ptr<Asset>* output); 65 66 // Creates a TFConcreteFunction from a SavedConcreteFunction. 67 Status LoadTFConcreteFunction( 68 const SavedConcreteFunction& saved_concrete_function, 69 const FunctionDef* function_def, 70 const std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>& 71 captured_objects, 72 ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out); 73 74 // Flattens `signature` into a vector of TensorSpecProto pointers back into 75 // `signature`. `signature` must outlive flattened_specs. `signature` must also 76 // be the input or output signature of a SavedConcreteFunction (i.e. "nested 77 // structures of tensorspecs"). 78 Status FlattenSignature(const StructuredValue& signature, 79 std::vector<const TensorSpecProto*>* flattened_specs); 80 81 // Find the node id in `object_graph` at location `path`. `path` must be 82 // a dot-delimited string of object names relative to the root object. If no 83 // object is found, returns absl::nullopt. 84 absl::optional<int> FindNodeAtPath(StringPiece path, 85 const SavedObjectGraph& object_graph); 86 87 // Maps each node in `graphdef` to its corresponding Attribute Map. 88 // Callers must ensure that `graphdef` outlives the returned map. 89 gtl::FlatMap<StringPiece, const AttrValueMap*, StringPieceHasher> NodeToAttrMap( 90 const tensorflow::GraphDef& graphdef); 91 92 // Maps the name of each FunctionDef in `library` to its corresponding 93 // FunctionDef. Callers must ensure `library` outlives the returned map. 94 gtl::FlatMap<StringPiece, const tensorflow::FunctionDef*, StringPieceHasher> 95 FunctionNameToFunctionDefMap(const FunctionDefLibrary& library); 96 97 // Walks through the SavedObjectGraph in metagraph, and restores all nodes 98 // (except "UserDefinedObjects") with their corresponding type in 99 // "PartiallyRevivedObjects". 100 Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph, 101 ImmediateExecutionContext* context, 102 const std::string& directory, 103 PartiallyRevivedObjects* objects); 104 105 } // namespace internal 106 } // namespace tensorflow 107 108 #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_ 109