1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/contrib/session_bundle/bundle_shim.h"
17 
18 #include "google/protobuf/any.pb.h"
19 #include "tensorflow/cc/saved_model/signature_constants.h"
20 #include "tensorflow/cc/saved_model/tag_constants.h"
21 #include "tensorflow/contrib/session_bundle/test_util.h"
22 #include "tensorflow/core/example/example.pb.h"
23 #include "tensorflow/core/example/feature.pb.h"
24 #include "tensorflow/core/framework/tensor_testutil.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/lib/io/path.h"
27 #include "tensorflow/core/protobuf/meta_graph.pb.h"
28 
29 namespace tensorflow {
30 namespace serving {
31 namespace internal {
32 namespace {
33 
// Relative locations of the half_plus_two test exports. The first one is
// loaded as a legacy SessionBundle; the second one is a SavedModel export.
// Tests resolve them against a source root before loading.
constexpr char kSessionBundlePath[]{
    "session_bundle/testdata/half_plus_two/00000123"};
constexpr char kSavedModelBundlePath[]{
    "cc/saved_model/testdata/half_plus_two/00000123"};
38 
MakeSerializedExample(float x)39 string MakeSerializedExample(float x) {
40   tensorflow::Example example;
41   auto* feature_map = example.mutable_features()->mutable_feature();
42   (*feature_map)["x"].mutable_float_list()->add_value(x);
43   return example.SerializeAsString();
44 }
45 
ValidateHalfPlusTwo(const SavedModelBundle & saved_model_bundle,const string & input_tensor_name,const string & output_tensor_name)46 void ValidateHalfPlusTwo(const SavedModelBundle& saved_model_bundle,
47                          const string& input_tensor_name,
48                          const string& output_tensor_name) {
49   // Validate the half plus two behavior.
50   std::vector<string> serialized_examples;
51   for (float x : {0, 1, 2, 3}) {
52     serialized_examples.push_back(MakeSerializedExample(x));
53   }
54   Tensor input = test::AsTensor<string>(serialized_examples, TensorShape({4}));
55 
56   std::vector<Tensor> outputs;
57   TF_ASSERT_OK(saved_model_bundle.session->Run(
58       {{input_tensor_name, input}}, {output_tensor_name}, {}, &outputs));
59   ASSERT_EQ(outputs.size(), 1);
60   test::ExpectTensorEqual<float>(
61       outputs[0], test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
62 }
63 
LoadAndValidateSavedModelBundle(const string & export_dir,const std::unordered_set<string> & tags,const string & signature_def_key,bool expect_session_bundle)64 void LoadAndValidateSavedModelBundle(const string& export_dir,
65                                      const std::unordered_set<string>& tags,
66                                      const string& signature_def_key,
67                                      bool expect_session_bundle) {
68   SessionOptions session_options;
69   RunOptions run_options;
70   SavedModelBundle saved_model_bundle;
71   bool is_session_bundle = false;
72   TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(
73       session_options, run_options, export_dir, tags, &saved_model_bundle,
74       &is_session_bundle));
75   EXPECT_EQ(expect_session_bundle, is_session_bundle);
76   const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
77   const auto& signature_def_map = meta_graph_def.signature_def();
78 
79   const auto& regression_entry = signature_def_map.find(signature_def_key);
80   ASSERT_FALSE(regression_entry == signature_def_map.end());
81   SignatureDef regression_signature_def = regression_entry->second;
82 
83   EXPECT_EQ(1, regression_signature_def.inputs_size());
84   ASSERT_FALSE(regression_signature_def.inputs().find(kRegressInputs) ==
85                regression_signature_def.inputs().end());
86   TensorInfo input_tensor_info =
87       regression_signature_def.inputs().find(kRegressInputs)->second;
88   EXPECT_EQ(1, regression_signature_def.outputs_size());
89   // Ensure the TensorInfo has dtype populated.
90   EXPECT_EQ(DT_STRING, input_tensor_info.dtype());
91 
92   ASSERT_FALSE(regression_signature_def.outputs().find(kRegressOutputs) ==
93                regression_signature_def.outputs().end());
94   TensorInfo output_tensor_info =
95       regression_signature_def.outputs().find(kRegressOutputs)->second;
96   // Ensure the TensorInfo has dtype populated.
97   EXPECT_EQ(DT_FLOAT, output_tensor_info.dtype());
98   ValidateHalfPlusTwo(saved_model_bundle, input_tensor_info.name(),
99                       output_tensor_info.name());
100 }
101 
102 // Helper function to validate that the SignatureDef found in the MetaGraphDef
103 // with the provided key has the expected string representation.
ValidateSignatureDef(const MetaGraphDef & meta_graph_def,const string & key,const string & expected_string_signature_def)104 void ValidateSignatureDef(const MetaGraphDef& meta_graph_def, const string& key,
105                           const string& expected_string_signature_def) {
106   tensorflow::SignatureDef expected_signature;
107   CHECK(protobuf::TextFormat::ParseFromString(expected_string_signature_def,
108                                               &expected_signature));
109   auto iter = meta_graph_def.signature_def().find(key);
110   ASSERT_TRUE(iter != meta_graph_def.signature_def().end());
111   EXPECT_EQ(expected_signature.DebugString(), iter->second.DebugString());
112 }
113 
114 // Checks that the input map in a signature def is populated correctly.
TEST(BundleShimTest,AddInputToSignatureDef)115 TEST(BundleShimTest, AddInputToSignatureDef) {
116   SignatureDef signature_def;
117   const string tensor_name = "foo_tensor";
118   const string map_key = "foo_key";
119 
120   // Build a map of tensor-name to dtype, for the unit-test.
121   std::unordered_map<string, DataType> tensor_name_to_dtype;
122   tensor_name_to_dtype[tensor_name] = tensorflow::DT_STRING;
123 
124   AddInputToSignatureDef(tensor_name, tensor_name_to_dtype, map_key,
125                          &signature_def);
126   EXPECT_EQ(1, signature_def.inputs_size());
127   EXPECT_EQ(tensor_name, signature_def.inputs().find(map_key)->second.name());
128 }
129 
130 // Checks that the output map in a signature def is populated correctly.
TEST(BundleShimTest,AddOutputToSignatureDef)131 TEST(BundleShimTest, AddOutputToSignatureDef) {
132   SignatureDef signature_def;
133   const string tensor_name = "foo_tensor";
134   const string map_key = "foo_key";
135 
136   // Build a map of tensor-name to dtype, for the unit-test.
137   std::unordered_map<string, DataType> tensor_name_to_dtype;
138   tensor_name_to_dtype[tensor_name] = tensorflow::DT_STRING;
139 
140   AddOutputToSignatureDef(tensor_name, tensor_name_to_dtype, map_key,
141                           &signature_def);
142   EXPECT_EQ(1, signature_def.outputs_size());
143   EXPECT_EQ(tensor_name, signature_def.outputs().find(map_key)->second.name());
144 }
145 
146 // Checks that no signature defs are added if the default signature is missing.
TEST(BundleShimTest,DefaultSignatureMissing)147 TEST(BundleShimTest, DefaultSignatureMissing) {
148   MetaGraphDef meta_graph_def;
149   // Signatures signatures;
150   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
151   EXPECT_EQ(0, meta_graph_def.signature_def_size());
152 }
153 
154 // Checks that no signature defs are added if the default signature is empty.
TEST(BundleShimTest,DefaultSignatureEmpty)155 TEST(BundleShimTest, DefaultSignatureEmpty) {
156   Signatures signatures;
157   signatures.mutable_default_signature();
158 
159   MetaGraphDef meta_graph_def;
160   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
161       .mutable_any_list()
162       ->add_value()
163       ->PackFrom(signatures);
164   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
165   EXPECT_EQ(0, meta_graph_def.signature_def_size());
166 }
167 
168 // Checks the conversion to signature def for a regression default signature.
TEST(BundleShimTest,DefaultSignatureRegression)169 TEST(BundleShimTest, DefaultSignatureRegression) {
170   Signatures signatures;
171   RegressionSignature* regression_signature =
172       signatures.mutable_default_signature()->mutable_regression_signature();
173   regression_signature->mutable_input()->set_tensor_name("foo-input");
174   regression_signature->mutable_output()->set_tensor_name("foo-output");
175   MetaGraphDef meta_graph_def;
176   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
177       .mutable_any_list()
178       ->add_value()
179       ->PackFrom(signatures);
180   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
181   EXPECT_EQ(1, meta_graph_def.signature_def_size());
182   const auto actual_signature_def =
183       meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
184   EXPECT_EQ("foo-input", actual_signature_def->second.inputs()
185                              .find(kRegressInputs)
186                              ->second.name());
187   EXPECT_EQ("foo-output", actual_signature_def->second.outputs()
188                               .find(kRegressOutputs)
189                               ->second.name());
190   EXPECT_EQ(kRegressMethodName, actual_signature_def->second.method_name());
191 }
192 
193 // Checks the conversion to signature def for a classification default
194 // signature.
TEST(BundleShimTest,DefaultSignatureClassification)195 TEST(BundleShimTest, DefaultSignatureClassification) {
196   Signatures signatures;
197   ClassificationSignature* classification_signature =
198       signatures.mutable_default_signature()
199           ->mutable_classification_signature();
200   classification_signature->mutable_input()->set_tensor_name("foo-input");
201   classification_signature->mutable_classes()->set_tensor_name("foo-classes");
202   classification_signature->mutable_scores()->set_tensor_name("foo-scores");
203   MetaGraphDef meta_graph_def;
204   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
205       .mutable_any_list()
206       ->add_value()
207       ->PackFrom(signatures);
208   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
209   EXPECT_EQ(1, meta_graph_def.signature_def_size());
210   const auto actual_signature_def =
211       meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
212   EXPECT_EQ("foo-input", actual_signature_def->second.inputs()
213                              .find(kClassifyInputs)
214                              ->second.name());
215   EXPECT_EQ("foo-classes", actual_signature_def->second.outputs()
216                                .find(kClassifyOutputClasses)
217                                ->second.name());
218   EXPECT_EQ("foo-scores", actual_signature_def->second.outputs()
219                               .find(kClassifyOutputScores)
220                               ->second.name());
221   EXPECT_EQ(kClassifyMethodName, actual_signature_def->second.method_name());
222 }
223 
224 // Checks that generic default signatures are not up converted.
TEST(BundleShimTest,DefaultSignatureGeneric)225 TEST(BundleShimTest, DefaultSignatureGeneric) {
226   TensorBinding input_binding;
227   input_binding.set_tensor_name("foo-input");
228 
229   TensorBinding output_binding;
230   output_binding.set_tensor_name("foo-output");
231 
232   Signatures signatures;
233   GenericSignature* generic_signature =
234       signatures.mutable_default_signature()->mutable_generic_signature();
235   generic_signature->mutable_map()->insert({kPredictInputs, input_binding});
236   generic_signature->mutable_map()->insert({kPredictOutputs, output_binding});
237 
238   MetaGraphDef meta_graph_def;
239   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
240       .mutable_any_list()
241       ->add_value()
242       ->PackFrom(signatures);
243   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
244   EXPECT_EQ(0, meta_graph_def.signature_def_size());
245 }
246 
// Each named regression signature must be up-converted into its own
// SignatureDef, keyed by the signature's name and using the regress method.
TEST(BundleShimTest, NamedRegressionSignatures) {
  Signatures signatures;
  auto& named_signatures = *signatures.mutable_named_signatures();

  RegressionSignature* const foo =
      named_signatures["foo"].mutable_regression_signature();
  foo->mutable_input()->set_tensor_name("foo-input");
  foo->mutable_output()->set_tensor_name("foo-output");

  RegressionSignature* const bar =
      named_signatures["bar"].mutable_regression_signature();
  bar->mutable_input()->set_tensor_name("bar-input");
  bar->mutable_output()->set_tensor_name("bar-output");

  MetaGraphDef meta_graph_def;
  auto* signatures_any_list =
      (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
          .mutable_any_list();
  signatures_any_list->add_value()->PackFrom(signatures);
  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
  ASSERT_EQ(2, meta_graph_def.signature_def_size());

  // Expected SignatureDefs in protobuf text format.
  const string expected_foo =
      "inputs { "
      "  key: \"inputs\" "
      "  value { "
      "name: \"foo-input\" "
      "  } "
      "} "
      "outputs { "
      "  key: \"outputs\" "
      "  value { "
      "    name: \"foo-output\" "
      "  } "
      "} "
      "method_name: \"tensorflow/serving/regress\" ";
  const string expected_bar =
      "inputs { "
      "  key: \"inputs\" "
      "  value { "
      "name: \"bar-input\" "
      "  } "
      "} "
      "outputs { "
      "  key: \"outputs\" "
      "  value { "
      "    name: \"bar-output\" "
      "  } "
      "} "
      "method_name: \"tensorflow/serving/regress\" ";
  ValidateSignatureDef(meta_graph_def, "foo", expected_foo);
  ValidateSignatureDef(meta_graph_def, "bar", expected_bar);
}
299 
// Each named classification signature must be up-converted into its own
// SignatureDef keyed by its name, using the classify method. Only the output
// (classes or scores) actually set on the signature may appear.
TEST(BundleShimTest, NamedClassificationSignatures) {
  Signatures signatures;
  auto& named_signatures = *signatures.mutable_named_signatures();

  ClassificationSignature* const foo =
      named_signatures["foo"].mutable_classification_signature();
  foo->mutable_input()->set_tensor_name("foo-input");
  foo->mutable_classes()->set_tensor_name("foo-classes");

  ClassificationSignature* const bar =
      named_signatures["bar"].mutable_classification_signature();
  bar->mutable_input()->set_tensor_name("bar-input");
  bar->mutable_scores()->set_tensor_name("bar-scores");

  MetaGraphDef meta_graph_def;
  auto* signatures_any_list =
      (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
          .mutable_any_list();
  signatures_any_list->add_value()->PackFrom(signatures);
  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
  ASSERT_EQ(2, meta_graph_def.signature_def_size());

  // Expected SignatureDefs in protobuf text format.
  const string expected_foo =
      "inputs { "
      "  key: \"inputs\" "
      "  value { "
      "name: \"foo-input\" "
      "  } "
      "} "
      "outputs { "
      "  key: \"classes\" "
      "  value { "
      "    name: \"foo-classes\" "
      "  } "
      "} "
      "method_name: \"tensorflow/serving/classify\" ";
  const string expected_bar =
      "inputs { "
      "  key: \"inputs\" "
      "  value { "
      "name: \"bar-input\" "
      "  } "
      "} "
      "outputs { "
      "  key: \"scores\" "
      "  value { "
      "    name: \"bar-scores\" "
      "  } "
      "} "
      "method_name: \"tensorflow/serving/classify\" ";
  ValidateSignatureDef(meta_graph_def, "foo", expected_foo);
  ValidateSignatureDef(meta_graph_def, "bar", expected_bar);
}
353 
354 // Checks the Predict SignatureDef created when the named signatures have
355 // `inputs` and `outputs`.
TEST(BundleShimTest,NamedSignatureGenericInputsAndOutputs)356 TEST(BundleShimTest, NamedSignatureGenericInputsAndOutputs) {
357   TensorBinding input_binding;
358   input_binding.set_tensor_name("foo-input");
359 
360   TensorBinding output_binding;
361   output_binding.set_tensor_name("foo-output");
362 
363   Signatures signatures;
364   GenericSignature* input_generic_signature =
365       (*signatures.mutable_named_signatures())[kPredictInputs]
366           .mutable_generic_signature();
367   input_generic_signature->mutable_map()->insert({"foo-input", input_binding});
368 
369   GenericSignature* output_generic_signature =
370       (*signatures.mutable_named_signatures())[kPredictOutputs]
371           .mutable_generic_signature();
372   output_generic_signature->mutable_map()->insert(
373       {"foo-output", output_binding});
374 
375   MetaGraphDef meta_graph_def;
376   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
377       .mutable_any_list()
378       ->add_value()
379       ->PackFrom(signatures);
380   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
381   EXPECT_EQ(1, meta_graph_def.signature_def_size());
382   const auto actual_signature_def =
383       meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
384   ASSERT_FALSE(actual_signature_def == meta_graph_def.signature_def().end());
385   ASSERT_FALSE(actual_signature_def->second.inputs().find("foo-input") ==
386                actual_signature_def->second.inputs().end());
387   EXPECT_EQ(
388       "foo-input",
389       actual_signature_def->second.inputs().find("foo-input")->second.name());
390   ASSERT_FALSE(actual_signature_def->second.outputs().find("foo-output") ==
391                actual_signature_def->second.outputs().end());
392   EXPECT_EQ(
393       "foo-output",
394       actual_signature_def->second.outputs().find("foo-output")->second.name());
395   EXPECT_EQ(kPredictMethodName, actual_signature_def->second.method_name());
396 }
397 
398 // Checks that a signature def is not added if the named signatures is generic
399 // but does not have `inputs` and `outputs`.
TEST(BundleShimTest,NamedSignatureGenericNoInputsOrOutputs)400 TEST(BundleShimTest, NamedSignatureGenericNoInputsOrOutputs) {
401   TensorBinding input_binding;
402   input_binding.set_tensor_name("foo-input");
403 
404   TensorBinding output_binding;
405   output_binding.set_tensor_name("foo-output");
406 
407   Signatures signatures;
408   GenericSignature* generic_signature =
409       (*signatures.mutable_named_signatures())["unknown"]
410           .mutable_generic_signature();
411   generic_signature->mutable_map()->insert({kPredictInputs, input_binding});
412   generic_signature->mutable_map()->insert({kPredictOutputs, output_binding});
413 
414   MetaGraphDef meta_graph_def;
415   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
416       .mutable_any_list()
417       ->add_value()
418       ->PackFrom(signatures);
419   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
420   EXPECT_EQ(0, meta_graph_def.signature_def_size());
421 }
422 
423 // Checks that a signature def is not added when the named signatures have only
424 // one of `inputs` and `outputs`.
TEST(BundleShimTest,NamedSignatureGenericOnlyInput)425 TEST(BundleShimTest, NamedSignatureGenericOnlyInput) {
426   TensorBinding input_binding;
427   input_binding.set_tensor_name("foo-input");
428 
429   Signatures signatures;
430   GenericSignature* input_generic_signature =
431       (*signatures.mutable_named_signatures())[kPredictInputs]
432           .mutable_generic_signature();
433   input_generic_signature->mutable_map()->insert({"foo-input", input_binding});
434 
435   MetaGraphDef meta_graph_def;
436   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
437       .mutable_any_list()
438       ->add_value()
439       ->PackFrom(signatures);
440   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
441   EXPECT_EQ(0, meta_graph_def.signature_def_size());
442 }
443 
444 // Tests up-conversion of Signatures to SignatureDefs when both `default` and
445 // `named` signatures are present.
TEST(BundleShimTest,DefaultAndNamedSignatureWithPredict)446 TEST(BundleShimTest, DefaultAndNamedSignatureWithPredict) {
447   Signatures signatures;
448 
449   // Build a generic signature corresponding to `inputs` and add it to the
450   // Signatures to up-convert.
451   TensorBinding input_binding;
452   input_binding.set_tensor_name("foo-input");
453   GenericSignature* input_generic_signature =
454       (*signatures.mutable_named_signatures())[kPredictInputs]
455           .mutable_generic_signature();
456   input_generic_signature->mutable_map()->insert({"foo-input", input_binding});
457 
458   // Build a generic signature corresponding to `outputs` and add it to the
459   // Signatures to up-convert.
460   TensorBinding output_binding;
461   output_binding.set_tensor_name("foo-output");
462   GenericSignature* output_generic_signature =
463       (*signatures.mutable_named_signatures())[kPredictOutputs]
464           .mutable_generic_signature();
465   output_generic_signature->mutable_map()->insert(
466       {"foo-output", output_binding});
467 
468   // Build a regression signature and set it as the default signature.
469   RegressionSignature* inputs_regression_signature =
470       (*signatures.mutable_default_signature()).mutable_regression_signature();
471   inputs_regression_signature->mutable_input()->set_tensor_name("bar-input");
472 
473   // Up-convert the available signatures to SignatureDefs.
474   MetaGraphDef meta_graph_def;
475   (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
476       .mutable_any_list()
477       ->add_value()
478       ->PackFrom(signatures);
479   TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
480   EXPECT_EQ(2, meta_graph_def.signature_def_size());
481 
482   // Verify that the default regression signature is converted to a
483   // SignatureDef that corresponds to the kDefaultServingSignatureDefKey.
484   const auto actual_signature_def_regress =
485       meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
486   ASSERT_FALSE(actual_signature_def_regress ==
487                meta_graph_def.signature_def().end());
488   ASSERT_FALSE(
489       actual_signature_def_regress->second.inputs().find(kRegressInputs) ==
490       actual_signature_def_regress->second.inputs().end());
491 
492   // Verify that the `Predict` SignatureDef is created under a different key.
493   const auto actual_signature_def_predict = meta_graph_def.signature_def().find(
494       strings::StrCat(kDefaultServingSignatureDefKey, "_from_named"));
495   ASSERT_FALSE(actual_signature_def_predict ==
496                meta_graph_def.signature_def().end());
497   ASSERT_FALSE(
498       actual_signature_def_predict->second.inputs().find("foo-input") ==
499       actual_signature_def_predict->second.inputs().end());
500   EXPECT_EQ("foo-input", actual_signature_def_predict->second.inputs()
501                              .find("foo-input")
502                              ->second.name());
503   ASSERT_FALSE(
504       actual_signature_def_predict->second.outputs().find("foo-output") ==
505       actual_signature_def_predict->second.outputs().end());
506   EXPECT_EQ("foo-output", actual_signature_def_predict->second.outputs()
507                               .find("foo-output")
508                               ->second.name());
509   EXPECT_EQ(kPredictMethodName,
510             actual_signature_def_predict->second.method_name());
511 }
512 
513 // Checks a basic up conversion for half plus two for SessionBundle.
TEST(BundleShimTest,BasicExportSessionBundle)514 TEST(BundleShimTest, BasicExportSessionBundle) {
515   const std::unordered_set<string> tags = {"tag"};
516   const string session_bundle_export_dir =
517       test_util::TestSrcDirPath(kSessionBundlePath);
518   LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags,
519                                   kDefaultServingSignatureDefKey,
520                                   /*expect_session_bundle=*/true);
521 
522   // Verify that the named signature is also present.
523   SessionOptions session_options;
524   RunOptions run_options;
525   SavedModelBundle saved_model_bundle;
526   TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(session_options, run_options,
527                                                    session_bundle_export_dir,
528                                                    tags, &saved_model_bundle));
529   const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
530   const auto& signature_def_map = meta_graph_def.signature_def();
531   bool found_named_signature = false;
532   for (const auto& entry : signature_def_map) {
533     const string& key = entry.first;
534     const SignatureDef& signature_def = entry.second;
535 
536     // We're looking for the key that is *not* kDefaultServingSignatureDefKey.
537     if (key == kDefaultServingSignatureDefKey) {
538       continue;
539     }
540     found_named_signature = true;
541 
542     EXPECT_EQ(1, signature_def.inputs_size());
543     const auto it_inputs_x = signature_def.inputs().find("x");
544     EXPECT_FALSE(it_inputs_x == signature_def.inputs().end());
545     // Ensure the TensorInfo has name and dtype populated.
546     const TensorInfo& tensor_info_x = it_inputs_x->second;
547     EXPECT_EQ("x:0", tensor_info_x.name());
548     EXPECT_EQ(DT_FLOAT, tensor_info_x.dtype());
549 
550     EXPECT_EQ(1, signature_def.outputs_size());
551     const auto it_outputs_y = signature_def.outputs().find("y");
552     EXPECT_FALSE(it_outputs_y == signature_def.outputs().end());
553     // Ensure the TensorInfo has name and dtype populated.
554     const TensorInfo& tensor_info_y = it_outputs_y->second;
555     EXPECT_EQ("y:0", tensor_info_y.name());
556     EXPECT_EQ(DT_FLOAT, tensor_info_y.dtype());
557   }
558   EXPECT_TRUE(found_named_signature);
559 }
560 
561 // Checks a basic load for half plus two for SavedModelBundle.
TEST(BundleShimTest,BasicExportSavedModel)562 TEST(BundleShimTest, BasicExportSavedModel) {
563   const string saved_model_bundle_export_dir =
564       io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath);
565   LoadAndValidateSavedModelBundle(saved_model_bundle_export_dir,
566                                   {kSavedModelTagServe}, "regress_x_to_y",
567                                   /*expect_session_bundle=*/false);
568 }
569 
570 // Checks a basic load fails with an invalid export path.
TEST(BundleShimTest,InvalidPath)571 TEST(BundleShimTest, InvalidPath) {
572   const string invalid_export_dir = testing::TensorFlowSrcRoot();
573   SessionOptions session_options;
574   RunOptions run_options;
575   SavedModelBundle saved_model_bundle;
576   Status status = LoadSessionBundleOrSavedModelBundle(
577       session_options, run_options, invalid_export_dir, {kSavedModelTagServe},
578       &saved_model_bundle);
579   EXPECT_EQ(error::Code::NOT_FOUND, status.code());
580 }
581 
582 // Checks that if loading a session bundle fails, the error is propagated to
583 // LoadSessionBundleOrSavedModelBundle().
TEST(BundleShimTest,LoadSessionBundleError)584 TEST(BundleShimTest, LoadSessionBundleError) {
585   const string session_bundle_export_dir =
586       test_util::TestSrcDirPath(kSessionBundlePath);
587   SessionOptions session_options;
588   RunOptions run_options;
589   // Invalid threadpool index to use for session-run calls.
590   run_options.set_inter_op_thread_pool(100);
591   SavedModelBundle saved_model_bundle;
592   EXPECT_FALSE(LoadSessionBundleOrSavedModelBundle(session_options, run_options,
593                                                    session_bundle_export_dir,
594                                                    {"tag"}, &saved_model_bundle)
595                    .ok());
596 }
597 
598 }  // namespace
599 }  // namespace internal
600 }  // namespace serving
601 }  // namespace tensorflow
602