1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/saved_model/bundle_v2.h"
17 
18 #include "tensorflow/cc/saved_model/constants.h"
19 #include "tensorflow/core/lib/io/path.h"
20 #include "tensorflow/core/lib/strings/strcat.h"
21 #include "tensorflow/core/platform/env.h"
22 #include "tensorflow/core/platform/strcat.h"
23 #include "tensorflow/core/protobuf/saved_model.pb.h"
24 #include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
25 
26 namespace tensorflow {
27 
28 namespace {
29 
ReadSavedModelProto(const string & export_dir,SavedModel * saved_model_proto)30 Status ReadSavedModelProto(const string& export_dir,
31                            SavedModel* saved_model_proto) {
32   LOG(INFO) << "Reading SavedModel from: " << export_dir;
33 
34   const string saved_model_pb_path =
35       io::JoinPath(export_dir, kSavedModelFilenamePb);
36   if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
37     return ReadBinaryProto(Env::Default(), saved_model_pb_path,
38                            saved_model_proto);
39   }
40   const string saved_model_pbtxt_path =
41       io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
42   if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
43     return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
44                          saved_model_proto);
45   }
46   return Status(error::Code::NOT_FOUND,
47                 "Could not find SavedModel .pb or .pbtxt at supplied export "
48                 "directory path: " +
49                     export_dir);
50 }
51 
ReadSavedModelDebugInfoIfPresent(const string & export_dir,std::unique_ptr<GraphDebugInfo> * debug_info_proto)52 Status ReadSavedModelDebugInfoIfPresent(
53     const string& export_dir,
54     std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
55   LOG(INFO) << "Reading SavedModel debug info (if present) from: "
56             << export_dir;
57 
58   const string debug_info_pb_path =
59       io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
60   if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
61     GraphDebugInfo debug_info;
62     TF_RETURN_IF_ERROR(
63         ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
64     *debug_info_proto =
65         absl::make_unique<GraphDebugInfo>(std::move(debug_info));
66   }
67   return Status::OK();
68 }
69 
ReadCheckpointObjectGraph(BundleReader * bundle_reader,TrackableObjectGraph * object_graph)70 Status ReadCheckpointObjectGraph(BundleReader* bundle_reader,
71                                  TrackableObjectGraph* object_graph) {
72   Tensor object_graph_tensor;
73   TF_RETURN_WITH_CONTEXT_IF_ERROR(
74       bundle_reader->Lookup(kObjectGraphProtoKey, &object_graph_tensor),
75       "SavedModel checkpoint does not contain object graph.");
76   if (object_graph_tensor.dtype() != DT_STRING ||
77       object_graph_tensor.dims() != 0 ||
78       object_graph_tensor.NumElements() != 1) {
79     return Status(
80         error::Code::FAILED_PRECONDITION,
81         "SavedModel checkpoint object graph was not the correct type.");
82   }
83 
84   const tstring* object_graph_string = reinterpret_cast<const tstring*>(
85       object_graph_tensor.tensor_data().data());
86   if (!object_graph->ParseFromString(*object_graph_string)) {
87     return Status(
88         error::Code::FAILED_PRECONDITION,
89         "SavedModel checkpoint object graph could not be deserialized.");
90   }
91   return Status::OK();
92 }
93 
94 }  // namespace
95 
Load(const std::string & export_dir,SavedModelV2Bundle * const bundle)96 Status SavedModelV2Bundle::Load(const std::string& export_dir,
97                                 SavedModelV2Bundle* const bundle) {
98   SavedModel saved_model_proto;
99   TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto));
100 
101   // Load MetaGraphDef.
102   // In version 2 SavedModels, there is only one MetaGraphDef.
103   if (saved_model_proto.meta_graphs_size() != 1) {
104     return Status(
105         error::Code::INVALID_ARGUMENT,
106         strings::StrCat(
107             "SavedModelV2 should have exactly one MetaGraphDef but actually ",
108             "contains ", saved_model_proto.meta_graphs_size()));
109   }
110   bundle->meta_graph_def_ =
111       std::move(*saved_model_proto.mutable_meta_graphs(0));
112 
113   // Load GraphDebugInfo.
114   TF_RETURN_IF_ERROR(
115       ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_));
116 
117   const std::string variables_dir =
118       io::JoinPath(export_dir, kSavedModelVariablesDirectory);
119   if (!Env::Default()->FileExists(variables_dir).ok()) {
120     LOG(INFO)
121         << "No checkpoint found, assuming this is a program-only SavedModel";
122   } else {
123     // Load the variables checkpoint reader.
124     const std::string variables_prefix =
125         io::JoinPath(variables_dir, kSavedModelVariablesFilename);
126     bundle->variable_reader_.reset(
127         new BundleReader(Env::Default(), variables_prefix));
128     TF_RETURN_WITH_CONTEXT_IF_ERROR(
129         bundle->variable_reader_->status(),
130         "Unable to load SavedModel variables checkpoint from ",
131         variables_prefix);
132 
133     // Deserialize the object graph proto from the tensor bundle.
134     TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
135         bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
136   }
137 
138   return Status::OK();
139 }
140 
VisitObjectsToRestore(RestoreObjectsCallback callback)141 Status SavedModelV2Bundle::VisitObjectsToRestore(
142     RestoreObjectsCallback callback) {
143   if (saved_object_graph().nodes_size() == 0 ||
144       trackable_object_graph().nodes_size() == 0) {
145     return Status::OK();
146   }
147 
148   // Start from root nodes of both the SavedObjectGraph and TrackableObjectGraph
149   // and descend to leaves. Note that the TrackableObjectGraph can have cycles
150   // (as can the SavedObjectGraph).
151   // This is detected and cycle edges are skipped.
152   const SavedObject* root_saved_object = &saved_object_graph().nodes(0);
153   const TrackableObjectGraph::TrackableObject* root_trackable_object =
154       &trackable_object_graph().nodes(0);
155   absl::flat_hash_set<int> trackable_node_ids;
156   return RecurseObjectsToRestore(root_saved_object, 0, root_trackable_object,
157                                  std::string(), &trackable_node_ids,
158                                  std::move(callback));
159 }
160 
RecurseObjectsToRestore(const SavedObject * saved_object,int saved_object_node_id,const TrackableObjectGraph::TrackableObject * trackable_object,std::string object_name,absl::flat_hash_set<int> * seen_trackable_node_ids,RestoreObjectsCallback callback)161 Status SavedModelV2Bundle::RecurseObjectsToRestore(
162     const SavedObject* saved_object, int saved_object_node_id,
163     const TrackableObjectGraph::TrackableObject* trackable_object,
164     std::string object_name, absl::flat_hash_set<int>* seen_trackable_node_ids,
165     RestoreObjectsCallback callback) {
166   // Callback if any attributes or slot variables.
167   // Note that the root is always excluded from the search (it can never
168   // be a restorable object). This matches some logic on the Python side.
169   if (saved_object_node_id != 0 &&
170       (trackable_object->attributes_size() > 0 ||
171        trackable_object->slot_variables_size() > 0)) {
172     TF_RETURN_WITH_CONTEXT_IF_ERROR(
173         callback(saved_object_node_id, *trackable_object), "Unable to restore ",
174         object_name);
175   }
176 
177   for (const auto& trackable_child_ref : trackable_object->children()) {
178     const auto& local_name = trackable_child_ref.local_name();
179 
180     // Compute the full child name.
181     std::string child_name;
182     if (object_name.empty()) {
183       child_name = local_name;
184     } else {
185       child_name = strings::StrCat(object_name, ".", local_name);
186     }
187 
188     // Descend down the trackable graph.
189     int trackable_child_node_id = trackable_child_ref.node_id();
190     if (!seen_trackable_node_ids->insert(trackable_child_node_id).second) {
191       // Cycle or duplicate detected - ignore this branch.
192       continue;
193     }
194     if (trackable_child_node_id < 0 ||
195         trackable_child_node_id >= trackable_object_graph().nodes_size()) {
196       return Status(
197           errors::Code::FAILED_PRECONDITION,
198           strings::StrCat("Illegal trackable child node id for ", child_name));
199     }
200     const auto* trackable_child =
201         &trackable_object_graph().nodes(trackable_child_node_id);
202 
203     // Descend down the saved object graph.
204     int saved_child_node_id = -1;
205     const SavedObject* saved_child = nullptr;
206     for (const auto& saved_child_ref : saved_object->children()) {
207       if (saved_child_ref.local_name() == local_name) {
208         // Found.
209         saved_child_node_id = saved_child_ref.node_id();
210         if (saved_child_node_id >= 0 &&
211             saved_child_node_id < saved_object_graph().nodes_size()) {
212           saved_child = &saved_object_graph().nodes(saved_child_node_id);
213         }
214         break;
215       }
216     }
217 
218     if (!saved_child) {
219       return Status(
220           errors::Code::FAILED_PRECONDITION,
221           strings::StrCat("Could not find saved object to restore for ",
222                           child_name));
223     }
224 
225     TF_RETURN_IF_ERROR(RecurseObjectsToRestore(
226         saved_child, saved_child_node_id, trackable_child, child_name,
227         seen_trackable_node_ids, callback));
228   }
229   return Status::OK();
230 }
231 
232 }  // namespace tensorflow
233