1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/contrib/session_bundle/bundle_shim.h"
17
18 #include "tensorflow/cc/saved_model/loader.h"
19 #include "tensorflow/cc/saved_model/signature_constants.h"
20 #include "tensorflow/contrib/session_bundle/manifest.pb.h"
21 #include "tensorflow/contrib/session_bundle/session_bundle.h"
22 #include "tensorflow/contrib/session_bundle/signature.h"
23 #include "tensorflow/core/graph/graph_constructor.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/core/stringpiece.h"
27 #include "tensorflow/core/protobuf/meta_graph.pb.h"
28 #include "tensorflow/core/public/session.h"
29 #include "tensorflow/core/public/session_options.h"
30
31 namespace tensorflow {
32 namespace serving {
33 namespace {
34 ///////////////////////////////////////////////////////////////////////////////
35 // Helper functions to check Signature type.
36
IsClassificationSignature(const Signature & signature)37 bool IsClassificationSignature(const Signature& signature) {
38 return signature.type_case() == Signature::kClassificationSignature;
39 }
40
IsRegressionSignature(const Signature & signature)41 bool IsRegressionSignature(const Signature& signature) {
42 return signature.type_case() == Signature::kRegressionSignature;
43 }
44
45 ///////////////////////////////////////////////////////////////////////////////
46 // Helper functions to build `Classification`, `Regression` and `Predict`
47 // SignatureDefs.
48
BuildRegressionSignatureDef(const RegressionSignature & regression_signature,const std::unordered_map<string,DataType> & tensor_name_to_dtype)49 SignatureDef BuildRegressionSignatureDef(
50 const RegressionSignature& regression_signature,
51 const std::unordered_map<string, DataType>& tensor_name_to_dtype) {
52 SignatureDef signature_def;
53 signature_def.set_method_name(kRegressMethodName);
54 internal::AddInputToSignatureDef(regression_signature.input().tensor_name(),
55 tensor_name_to_dtype, kRegressInputs,
56 &signature_def);
57 internal::AddOutputToSignatureDef(regression_signature.output().tensor_name(),
58 tensor_name_to_dtype, kRegressOutputs,
59 &signature_def);
60 return signature_def;
61 }
62
BuildClassificationSignatureDef(const ClassificationSignature & classification_signature,const std::unordered_map<string,DataType> & tensor_name_to_dtype)63 SignatureDef BuildClassificationSignatureDef(
64 const ClassificationSignature& classification_signature,
65 const std::unordered_map<string, DataType>& tensor_name_to_dtype) {
66 SignatureDef signature_def;
67 signature_def.set_method_name(kClassifyMethodName);
68 internal::AddInputToSignatureDef(
69 classification_signature.input().tensor_name(), tensor_name_to_dtype,
70 kClassifyInputs, &signature_def);
71 internal::AddOutputToSignatureDef(
72 classification_signature.classes().tensor_name(), tensor_name_to_dtype,
73 kClassifyOutputClasses, &signature_def);
74 internal::AddOutputToSignatureDef(
75 classification_signature.scores().tensor_name(), tensor_name_to_dtype,
76 kClassifyOutputScores, &signature_def);
77 return signature_def;
78 }
79
MaybeBuildPredictSignatureDef(const std::unordered_map<string,DataType> & tensor_name_to_dtype,MetaGraphDef * meta_graph_def)80 Status MaybeBuildPredictSignatureDef(
81 const std::unordered_map<string, DataType>& tensor_name_to_dtype,
82 MetaGraphDef* meta_graph_def) {
83 Signature input_signature, output_signature;
84 // Ensure that named signatures corresponding to `inputs` and `outputs` keys
85 // exist.
86 if (!GetNamedSignature(kPredictInputs, *meta_graph_def, &input_signature)
87 .ok() ||
88 !GetNamedSignature(kPredictOutputs, *meta_graph_def, &output_signature)
89 .ok()) {
90 return Status(error::Code::INVALID_ARGUMENT,
91 "Named signatures can only be up-converted if entries "
92 "corresponding to both `inputs` and `outputs` exist.");
93 }
94 // Ensure the `inputs` and `outputs` named signatures are generic signatures.
95 if (input_signature.type_case() != Signature::TypeCase::kGenericSignature ||
96 output_signature.type_case() != Signature::TypeCase::kGenericSignature) {
97 return Status(error::Code::INVALID_ARGUMENT,
98 "Named signatures corresponding to `inputs` and `outputs` "
99 "can only be up-converted if they are GenericSignatures.");
100 }
101 SignatureDef signature_def;
102 signature_def.set_method_name(kPredictMethodName);
103 // Add map entries from the `inputs` generic signature to the input map in the
104 // signature def.
105 for (const auto& map_entry : input_signature.generic_signature().map()) {
106 internal::AddInputToSignatureDef(map_entry.second.tensor_name(),
107 tensor_name_to_dtype, map_entry.first,
108 &signature_def);
109 }
110 // Add map entries from the `outputs` generic signature to the output map in
111 // the signature def.
112 for (const auto& map_entry : output_signature.generic_signature().map()) {
113 internal::AddOutputToSignatureDef(map_entry.second.tensor_name(),
114 tensor_name_to_dtype, map_entry.first,
115 &signature_def);
116 }
117 // Add the constructed signature def to the signature def map of the meta
118 // graph def. Use the default key if it isn't already in use.
119 const bool already_has_default_signature =
120 meta_graph_def->signature_def().find(kDefaultServingSignatureDefKey) !=
121 meta_graph_def->signature_def().end();
122 const string signature_def_key =
123 already_has_default_signature
124 ? strings::StrCat(kDefaultServingSignatureDefKey, "_from_named")
125 : kDefaultServingSignatureDefKey;
126 (*meta_graph_def->mutable_signature_def())[signature_def_key] = signature_def;
127 return Status::OK();
128 }
129
LoadSavedModelFromLegacySessionBundlePath(const SessionOptions & session_options,const RunOptions & run_options,const StringPiece session_bundle_export_dir,SavedModelBundle * saved_model_bundle)130 Status LoadSavedModelFromLegacySessionBundlePath(
131 const SessionOptions& session_options, const RunOptions& run_options,
132 const StringPiece session_bundle_export_dir,
133 SavedModelBundle* saved_model_bundle) {
134 if (session_bundle_export_dir.empty()) {
135 return Status(error::Code::NOT_FOUND, "Export directory path is empty.");
136 }
137 if (!IsPossibleExportDirectory(session_bundle_export_dir)) {
138 return Status(
139 error::Code::NOT_FOUND,
140 "Export directory does not contain a valid SessionBundle export.");
141 }
142
143 // Build the session-bundle.
144 SessionBundle session_bundle;
145 TF_RETURN_IF_ERROR(LoadSessionBundleFromPathUsingRunOptions(
146 session_options, run_options, session_bundle_export_dir,
147 &session_bundle));
148
149 // Convert the session-bundle to a saved-model-bundle.
150 return internal::ConvertSessionBundleToSavedModelBundle(session_bundle,
151 saved_model_bundle);
152 }
153
154 ///////////////////////////////////////////////////////////////////////////////
155 // Helper functions to convert `Default` and `Named` signatures to
156 // SignatureDefs.
157
158 // Up-conversion of default signatures is supported for classification and
159 // regression.
ConvertDefaultSignatureToSignatureDef(const Signatures & signatures,const std::unordered_map<string,DataType> & tensor_name_to_dtype,MetaGraphDef * meta_graph_def)160 Status ConvertDefaultSignatureToSignatureDef(
161 const Signatures& signatures,
162 const std::unordered_map<string, DataType>& tensor_name_to_dtype,
163 MetaGraphDef* meta_graph_def) {
164 if (!signatures.has_default_signature()) {
165 return Status::OK();
166 }
167 const bool already_has_default_signature =
168 meta_graph_def->signature_def().find(kDefaultServingSignatureDefKey) !=
169 meta_graph_def->signature_def().end();
170 if (already_has_default_signature) {
171 return Status(error::Code::ALREADY_EXISTS,
172 strings::StrCat(
173 "Default signature cannot be up-converted since ",
174 kDefaultServingSignatureDefKey, " key already exists."));
175 }
176 const Signature& signature = signatures.default_signature();
177 if (IsRegressionSignature(signature)) {
178 (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] =
179 BuildRegressionSignatureDef(signature.regression_signature(),
180 tensor_name_to_dtype);
181 } else if (IsClassificationSignature(signature)) {
182 (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] =
183 BuildClassificationSignatureDef(signature.classification_signature(),
184 tensor_name_to_dtype);
185 } else {
186 LOG(WARNING) << "Default signature up-conversion to SignatureDef is only "
187 "supported for `Classification` and `Regression`. Could "
188 "not up-convert signature: "
189 << signature.DebugString()
190 << ". (If using SessionRun with the SessionBundle export "
191 "format please ignore this warning.)";
192 }
193 return Status::OK();
194 }
195
ConvertNamedSignaturesToSignatureDef(const Signatures & signatures,const std::unordered_map<string,DataType> & tensor_name_to_dtype,MetaGraphDef * meta_graph_def)196 Status ConvertNamedSignaturesToSignatureDef(
197 const Signatures& signatures,
198 const std::unordered_map<string, DataType>& tensor_name_to_dtype,
199 MetaGraphDef* meta_graph_def) {
200 if (signatures.named_signatures().empty()) {
201 return Status::OK();
202 }
203 // Check for a Predict signature for up-conversion.
204 Status predict_signature_def_status =
205 MaybeBuildPredictSignatureDef(tensor_name_to_dtype, meta_graph_def);
206 for (const auto& it_named_signature : signatures.named_signatures()) {
207 const string key = it_named_signature.first;
208 // If a Predict SignatureDef was successfully constructed, skip the entries
209 // corresponding to `inputs` and `outputs`.
210 if (predict_signature_def_status.ok()) {
211 if (key == kPredictInputs || key == kPredictOutputs) {
212 continue;
213 }
214 }
215 const Signature signature = it_named_signature.second;
216 if (IsRegressionSignature(signature)) {
217 (*meta_graph_def->mutable_signature_def())[key] =
218 BuildRegressionSignatureDef(signature.regression_signature(),
219 tensor_name_to_dtype);
220 } else if (IsClassificationSignature(signature)) {
221 (*meta_graph_def->mutable_signature_def())[key] =
222 BuildClassificationSignatureDef(signature.classification_signature(),
223 tensor_name_to_dtype);
224 } else {
225 LOG(WARNING)
226 << "Named signature up-conversion to SignatureDef is only supported "
227 "for `Classification`, `Regression` or if two `GenericSignatures` "
228 "signatures called `inputs` and `outputs` exist, corresponding "
229 "to the `Prediction` API. Could not up-convert signature: "
230 << signature.DebugString();
231 }
232 }
233 return Status::OK();
234 }
235
236 } // namespace
237
238 namespace internal {
239 ///////////////////////////////////////////////////////////////////////////////
240 // Helper functions to populate SignatureDef fields.
241
242 // Adds an entry to the `inputs` map of the supplied SignatureDef.
AddInputToSignatureDef(const string & tensor_name,const std::unordered_map<string,DataType> & tensor_name_to_dtype,const string & input_key,SignatureDef * signature_def)243 void AddInputToSignatureDef(
244 const string& tensor_name,
245 const std::unordered_map<string, DataType>& tensor_name_to_dtype,
246 const string& input_key, SignatureDef* signature_def) {
247 if (tensor_name.empty()) {
248 LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to "
249 "SignatureDef inputs.";
250 return;
251 }
252 // Extract the tensor-name in case the supplied string is a tensor-reference.
253 // Example: Extract "x" from "x:0".
254 std::size_t pos = tensor_name.find(":");
255 const string key =
256 (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name;
257 const auto it_tensor_info = tensor_name_to_dtype.find(key);
258 TensorInfo tensor_info;
259 tensor_info.set_name(tensor_name);
260 if (it_tensor_info != tensor_name_to_dtype.end()) {
261 tensor_info.set_dtype(it_tensor_info->second);
262 } else {
263 LOG(WARNING)
264 << "No dtype found for tensor with name: " << tensor_name << ". "
265 << "Building TensorInfo with only name for SignatureDef inputs. "
266 << "Downstream functionality including validation may be "
267 << "impacted.";
268 }
269 (*signature_def->mutable_inputs())[input_key] = tensor_info;
270 }
271
272 // Adds an entry to the `outputs` map of the supplied SignatureDef.
AddOutputToSignatureDef(const string & tensor_name,const std::unordered_map<string,DataType> & tensor_name_to_dtype,const string & output_key,SignatureDef * signature_def)273 void AddOutputToSignatureDef(
274 const string& tensor_name,
275 const std::unordered_map<string, DataType>& tensor_name_to_dtype,
276 const string& output_key, SignatureDef* signature_def) {
277 if (tensor_name.empty()) {
278 LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to "
279 "SignatureDef outputs.";
280 return;
281 }
282 // Extract the tensor-name in case the supplied string is a tensor-reference.
283 // Example: Extract "x" from "x:0".
284 std::size_t pos = tensor_name.find(":");
285 const string key =
286 (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name;
287 const auto it_tensor_info = tensor_name_to_dtype.find(key);
288 TensorInfo tensor_info;
289 tensor_info.set_name(tensor_name);
290 if (it_tensor_info != tensor_name_to_dtype.end()) {
291 tensor_info.set_dtype(it_tensor_info->second);
292 } else {
293 LOG(WARNING)
294 << "No dtype found for tensor with name: " << tensor_name << ". "
295 << "Building TensorInfo with only name for SignatureDef outputs."
296 << " Downstream functionality including validation may be "
297 << "impacted.";
298 }
299 (*signature_def->mutable_outputs())[output_key] = tensor_info;
300 }
301
302 // Builds a map from tensor name to the corresponding datatype, by parsing the
303 // MetaGraphDef.
BuildTensorNameToDtypeMap(const MetaGraphDef & meta_graph_def,std::unordered_map<string,DataType> * tensor_name_to_dtype)304 Status BuildTensorNameToDtypeMap(
305 const MetaGraphDef& meta_graph_def,
306 std::unordered_map<string, DataType>* tensor_name_to_dtype) {
307 GraphConstructorOptions opts;
308 Graph graph(OpRegistry::Global());
309 TF_RETURN_IF_ERROR(
310 ConvertGraphDefToGraph(opts, meta_graph_def.graph_def(), &graph));
311 for (Node* node : graph.nodes()) {
312 for (auto dt : node->output_types()) {
313 tensor_name_to_dtype->insert(std::make_pair(node->name(), dt));
314 }
315 }
316 return Status::OK();
317 }
318
319 // Converts SessionBundle signatures to SavedModel signature-defs.
ConvertSignaturesToSignatureDefs(MetaGraphDef * meta_graph_def)320 Status ConvertSignaturesToSignatureDefs(MetaGraphDef* meta_graph_def) {
321 Signatures signatures;
322 GetSignatures(*meta_graph_def, &signatures).IgnoreError();
323
324 // Build a map of tensor-names to the corresponding tensor-info with `name`
325 // and `dtype` fields.
326 std::unordered_map<string, DataType> tensor_name_to_dtype;
327 TF_RETURN_IF_ERROR(
328 BuildTensorNameToDtypeMap(*meta_graph_def, &tensor_name_to_dtype));
329
330 TF_RETURN_IF_ERROR(ConvertDefaultSignatureToSignatureDef(
331 signatures, tensor_name_to_dtype, meta_graph_def));
332 TF_RETURN_IF_ERROR(ConvertNamedSignaturesToSignatureDef(
333 signatures, tensor_name_to_dtype, meta_graph_def));
334 return Status::OK();
335 }
336
337 // Converts a SessionBundle to a SavedModelBundle.
ConvertSessionBundleToSavedModelBundle(SessionBundle & session_bundle,SavedModelBundle * saved_model_bundle)338 Status ConvertSessionBundleToSavedModelBundle(
339 SessionBundle& session_bundle, SavedModelBundle* saved_model_bundle) {
340 // Transfer ownership of the session from old to new.
341 saved_model_bundle->session = std::move(session_bundle.session);
342
343 // Copy the meta graph def from the SessionBundle to the SavedModelBundle.
344 saved_model_bundle->meta_graph_def = session_bundle.meta_graph_def;
345
346 // Convert signatures from session-bundle to signature-defs in
347 // saved-model-bundle.
348 return internal::ConvertSignaturesToSignatureDefs(
349 &saved_model_bundle->meta_graph_def);
350 }
351
352 } // namespace internal
353
LoadSessionBundleOrSavedModelBundle(const SessionOptions & session_options,const RunOptions & run_options,const string & export_dir,const std::unordered_set<string> & saved_model_tags,SavedModelBundle * saved_model_bundle,bool * is_session_bundle)354 Status LoadSessionBundleOrSavedModelBundle(
355 const SessionOptions& session_options, const RunOptions& run_options,
356 const string& export_dir,
357 const std::unordered_set<string>& saved_model_tags,
358 SavedModelBundle* saved_model_bundle, bool* is_session_bundle) {
359 if (is_session_bundle != nullptr) {
360 *is_session_bundle = false;
361 }
362 if (MaybeSavedModelDirectory(export_dir)) {
363 LOG(INFO)
364 << "Attempting to load native SavedModelBundle in bundle-shim from: "
365 << export_dir;
366
367 return LoadSavedModel(session_options, run_options, export_dir,
368 saved_model_tags, saved_model_bundle);
369 } else if (IsPossibleExportDirectory(export_dir)) {
370 LOG(ERROR) << "Found possible SessionBundle in export directory. "
371 "SessionBundle is deprecated. Use SavedModel instead.";
372 LOG(INFO) << "Attempting to up-convert SessionBundle to SavedModelBundle "
373 "in bundle-shim from: "
374 << export_dir;
375 if (is_session_bundle != nullptr) {
376 *is_session_bundle = true;
377 }
378 return LoadSavedModelFromLegacySessionBundlePath(
379 session_options, run_options, export_dir, saved_model_bundle);
380 }
381 return Status(
382 error::Code::NOT_FOUND,
383 strings::StrCat(
384 "Specified file path does not appear to contain a:\n"
385 "- Session bundle (should have a file called `export.meta`)\n"
386 "- or, SavedModel bundle (should have a file called "
387 "`saved_model.pb`)\n"
388 "Specified file path: ",
389 export_dir));
390 }
391
392 } // namespace serving
393 } // namespace tensorflow
394