1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/contrib/session_bundle/signature.h"
17 
18 #include <memory>
19 
20 #include "google/protobuf/any.pb.h"
21 #include "tensorflow/contrib/session_bundle/manifest.pb.h"
22 #include "tensorflow/core/framework/graph.pb.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_testutil.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/lib/core/stringpiece.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/platform/test.h"
31 #include "tensorflow/core/public/session.h"
32 
33 namespace tensorflow {
34 namespace serving {
35 namespace {
36 
HasSubstr(StringPiece base,StringPiece substr)37 static bool HasSubstr(StringPiece base, StringPiece substr) {
38   bool ok = str_util::StrContains(base, substr);
39   EXPECT_TRUE(ok) << base << ", expected substring " << substr;
40   return ok;
41 }
42 
// A default classification signature packed into the signatures collection is
// read back with its input tensor name intact.
TEST(GetClassificationSignature, Basic) {
  Signatures stored;
  stored.mutable_default_signature()
      ->mutable_classification_signature()
      ->mutable_input()
      ->set_tensor_name("flow");

  tensorflow::MetaGraphDef meta_graph_def;
  auto& collections = *meta_graph_def.mutable_collection_def();
  collections[kSignaturesKey].mutable_any_list()->add_value()->PackFrom(
      stored);

  ClassificationSignature fetched;
  TF_ASSERT_OK(GetClassificationSignature(meta_graph_def, &fetched));
  EXPECT_EQ(fetched.input().tensor_name(), "flow");
}
60 
// The default signature exists but no concrete signature type was set on it,
// so the lookup is expected to fail with a classification-specific message.
TEST(GetClassificationSignature, MissingSignature) {
  Signatures stored;
  stored.mutable_default_signature();  // Created, but left without a type.

  tensorflow::MetaGraphDef meta_graph_def;
  auto* packed = (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
                     .mutable_any_list()
                     ->add_value();
  packed->PackFrom(stored);

  ClassificationSignature fetched;
  const Status status = GetClassificationSignature(meta_graph_def, &fetched);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(
      str_util::StrContains(message, "Expected a classification signature"))
      << message;
}
77 
// The default signature is a regression signature, so asking for a
// classification signature is expected to fail.
TEST(GetClassificationSignature, WrongSignatureType) {
  Signatures stored;
  Signature* const default_signature = stored.mutable_default_signature();
  default_signature->mutable_regression_signature();

  tensorflow::MetaGraphDef meta_graph_def;
  auto* packed = (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
                     .mutable_any_list()
                     ->add_value();
  packed->PackFrom(stored);

  ClassificationSignature fetched;
  const Status status = GetClassificationSignature(meta_graph_def, &fetched);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(
      str_util::StrContains(message, "Expected a classification signature"))
      << message;
}
94 
// A classification signature registered under the name "foo" is found by name
// and returned with its input tensor name intact.
TEST(GetNamedClassificationSignature, Basic) {
  Signatures stored;
  auto& named = *stored.mutable_named_signatures();
  named["foo"]
      .mutable_classification_signature()
      ->mutable_input()
      ->set_tensor_name("flow");

  tensorflow::MetaGraphDef meta_graph_def;
  auto& collections = *meta_graph_def.mutable_collection_def();
  collections[kSignaturesKey].mutable_any_list()->add_value()->PackFrom(
      stored);

  ClassificationSignature fetched;
  TF_ASSERT_OK(
      GetNamedClassificationSignature("foo", meta_graph_def, &fetched));
  EXPECT_EQ(fetched.input().tensor_name(), "flow");
}
113 
// The packed Signatures proto has no named signatures at all, so looking up
// "foo" is expected to fail with a "missing" message.
TEST(GetNamedClassificationSignature, MissingSignature) {
  const Signatures empty_signatures;

  tensorflow::MetaGraphDef meta_graph_def;
  auto* packed = (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
                     .mutable_any_list()
                     ->add_value();
  packed->PackFrom(empty_signatures);

  ClassificationSignature fetched;
  const Status status =
      GetNamedClassificationSignature("foo", meta_graph_def, &fetched);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(
      str_util::StrContains(message, "Missing signature named \"foo\""))
      << message;
}
130 
// "foo" is registered as a regression signature, so requesting it as a
// classification signature is expected to fail.
TEST(GetNamedClassificationSignature, WrongSignatureType) {
  Signatures stored;
  auto& named = *stored.mutable_named_signatures();
  named["foo"].mutable_regression_signature();

  tensorflow::MetaGraphDef meta_graph_def;
  auto* packed = (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
                     .mutable_any_list()
                     ->add_value();
  packed->PackFrom(stored);

  ClassificationSignature fetched;
  const Status status =
      GetNamedClassificationSignature("foo", meta_graph_def, &fetched);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(
      message, "Expected a classification signature for name \"foo\""))
      << message;
}
150 
// A default regression signature packed into the signatures collection is
// read back with its input tensor name intact.
TEST(GetRegressionSignature, Basic) {
  Signatures stored;
  stored.mutable_default_signature()
      ->mutable_regression_signature()
      ->mutable_input()
      ->set_tensor_name("flow");

  tensorflow::MetaGraphDef meta_graph_def;
  auto& collections = *meta_graph_def.mutable_collection_def();
  collections[kSignaturesKey].mutable_any_list()->add_value()->PackFrom(
      stored);

  RegressionSignature fetched;
  TF_ASSERT_OK(GetRegressionSignature(meta_graph_def, &fetched));
  EXPECT_EQ(fetched.input().tensor_name(), "flow");
}
167 
// The default signature exists but no concrete signature type was set on it,
// so the lookup is expected to fail with a regression-specific message.
TEST(GetRegressionSignature, MissingSignature) {
  Signatures stored;
  stored.mutable_default_signature();  // Created, but left without a type.

  tensorflow::MetaGraphDef meta_graph_def;
  auto* packed = (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
                     .mutable_any_list()
                     ->add_value();
  packed->PackFrom(stored);

  RegressionSignature fetched;
  const Status status = GetRegressionSignature(meta_graph_def, &fetched);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(
      str_util::StrContains(message, "Expected a regression signature"))
      << message;
}
184 
// The default signature is a classification signature, so asking for a
// regression signature is expected to fail.
TEST(GetRegressionSignature, WrongSignatureType) {
  Signatures stored;
  Signature* const default_signature = stored.mutable_default_signature();
  default_signature->mutable_classification_signature();

  tensorflow::MetaGraphDef meta_graph_def;
  auto* packed = (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
                     .mutable_any_list()
                     ->add_value();
  packed->PackFrom(stored);

  RegressionSignature fetched;
  const Status status = GetRegressionSignature(meta_graph_def, &fetched);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(
      str_util::StrContains(message, "Expected a regression signature"))
      << message;
}
201 
// GetNamedSignature returns the generic Signature wrapper; the classification
// signature stored under "foo" must be reachable through it.
TEST(GetNamedSignature, Basic) {
  Signatures stored;
  auto& named = *stored.mutable_named_signatures();
  named["foo"]
      .mutable_classification_signature()
      ->mutable_input()
      ->set_tensor_name("flow");

  tensorflow::MetaGraphDef meta_graph_def;
  auto& collections = *meta_graph_def.mutable_collection_def();
  collections[kSignaturesKey].mutable_any_list()->add_value()->PackFrom(
      stored);

  Signature fetched;
  TF_ASSERT_OK(GetNamedSignature("foo", meta_graph_def, &fetched));
  const auto& classification = fetched.classification_signature();
  EXPECT_EQ(classification.input().tensor_name(), "flow");
}
219 
// The packed Signatures proto has no named signatures, so looking up "foo" is
// expected to fail with a "missing" message.
TEST(GetNamedSignature, MissingSignature) {
  const Signatures empty_signatures;

  tensorflow::MetaGraphDef meta_graph_def;
  auto* packed = (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
                     .mutable_any_list()
                     ->add_value();
  packed->PackFrom(empty_signatures);

  Signature fetched;
  const Status status = GetNamedSignature("foo", meta_graph_def, &fetched);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(
      str_util::StrContains(message, "Missing signature named \"foo\""))
      << message;
}
235 
236 // MockSession used to test input and output interactions with a
237 // tensorflow::Session.
238 struct MockSession : public tensorflow::Session {
239   ~MockSession() override = default;
240 
Createtensorflow::serving::__anond1cc78060111::MockSession241   Status Create(const GraphDef& graph) override {
242     return errors::Unimplemented("Not implemented for mock.");
243   }
244 
Extendtensorflow::serving::__anond1cc78060111::MockSession245   Status Extend(const GraphDef& graph) override {
246     return errors::Unimplemented("Not implemented for mock.");
247   }
248 
249   // Sets the input and output arguments.
Runtensorflow::serving::__anond1cc78060111::MockSession250   Status Run(const std::vector<std::pair<string, Tensor>>& inputs_arg,
251              const std::vector<string>& output_tensor_names_arg,
252              const std::vector<string>& target_node_names_arg,
253              std::vector<Tensor>* outputs_arg) override {
254     inputs = inputs_arg;
255     output_tensor_names = output_tensor_names_arg;
256     target_node_names = target_node_names_arg;
257     *outputs_arg = outputs;
258     return status;
259   }
260 
Closetensorflow::serving::__anond1cc78060111::MockSession261   Status Close() override {
262     return errors::Unimplemented("Not implemented for mock.");
263   }
264 
ListDevicestensorflow::serving::__anond1cc78060111::MockSession265   Status ListDevices(std::vector<DeviceAttributes>* response) override {
266     return errors::Unimplemented("Not implemented for mock.");
267   }
268 
269   // Arguments stored on a Run call.
270   std::vector<std::pair<string, Tensor>> inputs;
271   std::vector<string> output_tensor_names;
272   std::vector<string> target_node_names;
273 
274   // Output argument set by Run; should be set before calling.
275   std::vector<Tensor> outputs;
276 
277   // Return value for Run; should be set before calling.
278   Status status;
279 };
280 
// Tensor names bound into the signatures built by the fixtures below.
constexpr char kInputName[]{"in:0"};
constexpr char kClassesName[]{"classes:0"};
constexpr char kScoresName[]{"scores:0"};
284 
285 class RunClassificationTest : public ::testing::Test {
286  public:
SetUp()287   void SetUp() override {
288     signature_.mutable_input()->set_tensor_name(kInputName);
289     signature_.mutable_classes()->set_tensor_name(kClassesName);
290     signature_.mutable_scores()->set_tensor_name(kScoresName);
291   }
292 
293  protected:
294   ClassificationSignature signature_;
295   Tensor input_tensor_;
296   Tensor classes_tensor_;
297   Tensor scores_tensor_;
298   MockSession session_;
299 };
300 
// Requests both classes and scores: both outputs must come back, and the
// session must have been fed the input and asked for both tensor names.
TEST_F(RunClassificationTest, Basic) {
  input_tensor_ = test::AsTensor<int>({99});
  session_.outputs = {test::AsTensor<int>({3}), test::AsTensor<int>({2})};

  TF_ASSERT_OK(RunClassification(signature_, input_tensor_, &session_,
                                 &classes_tensor_, &scores_tensor_));

  // What RunClassification handed back.
  test::ExpectTensorEqual<int>(test::AsTensor<int>({3}), classes_tensor_);
  test::ExpectTensorEqual<int>(test::AsTensor<int>({2}), scores_tensor_);

  // What the session was fed.
  const auto& fed = session_.inputs;
  ASSERT_EQ(1, fed.size());
  EXPECT_EQ(kInputName, fed[0].first);
  test::ExpectTensorEqual<int>(test::AsTensor<int>({99}), fed[0].second);

  // What the session was asked to fetch.
  const auto& requested = session_.output_tensor_names;
  ASSERT_EQ(2, requested.size());
  EXPECT_EQ(kClassesName, requested[0]);
  EXPECT_EQ(kScoresName, requested[1]);
}
322 
// Passes nullptr for scores: only the classes tensor is expected to be
// requested from the session and returned.
TEST_F(RunClassificationTest, ClassesOnly) {
  input_tensor_ = test::AsTensor<int>({99});
  session_.outputs = {test::AsTensor<int>({3})};

  TF_ASSERT_OK(RunClassification(signature_, input_tensor_, &session_,
                                 &classes_tensor_, nullptr));

  // What RunClassification handed back.
  test::ExpectTensorEqual<int>(test::AsTensor<int>({3}), classes_tensor_);

  // What the session was fed.
  const auto& fed = session_.inputs;
  ASSERT_EQ(1, fed.size());
  EXPECT_EQ(kInputName, fed[0].first);
  test::ExpectTensorEqual<int>(test::AsTensor<int>({99}), fed[0].second);

  // What the session was asked to fetch.
  const auto& requested = session_.output_tensor_names;
  ASSERT_EQ(1, requested.size());
  EXPECT_EQ(kClassesName, requested[0]);
}
342 
// Passes nullptr for classes: only the scores tensor is expected to be
// requested from the session and returned.
TEST_F(RunClassificationTest, ScoresOnly) {
  input_tensor_ = test::AsTensor<int>({99});
  session_.outputs = {test::AsTensor<int>({2})};

  TF_ASSERT_OK(RunClassification(signature_, input_tensor_, &session_, nullptr,
                                 &scores_tensor_));

  // What RunClassification handed back.
  test::ExpectTensorEqual<int>(test::AsTensor<int>({2}), scores_tensor_);

  // What the session was fed.
  const auto& fed = session_.inputs;
  ASSERT_EQ(1, fed.size());
  EXPECT_EQ(kInputName, fed[0].first);
  test::ExpectTensorEqual<int>(test::AsTensor<int>({99}), fed[0].second);

  // What the session was asked to fetch.
  const auto& requested = session_.output_tensor_names;
  ASSERT_EQ(1, requested.size());
  EXPECT_EQ(kScoresName, requested[0]);
}
362 
// The mock session's Run() returns a DataLoss error; RunClassification is
// expected to fail and carry that error text.
TEST(RunClassification, RunNotOk) {
  MockSession session;
  session.status = errors::DataLoss("Data is gone");

  ClassificationSignature signature;
  signature.mutable_input()->set_tensor_name("in:0");
  signature.mutable_classes()->set_tensor_name("classes:0");

  const Tensor input_tensor = test::AsTensor<int>({99});
  Tensor classes_tensor;
  const Status status = RunClassification(signature, input_tensor, &session,
                                          &classes_tensor, nullptr);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Data is gone")) << message;
}
377 
// Only classes are requested, yet the mock session hands back two tensors;
// RunClassification is expected to reject the surplus output.
TEST(RunClassification, TooManyOutputs) {
  MockSession session;
  session.outputs = {test::AsTensor<int>({3}), test::AsTensor<int>({4})};

  ClassificationSignature signature;
  signature.mutable_input()->set_tensor_name("in:0");
  signature.mutable_classes()->set_tensor_name("classes:0");

  const Tensor input_tensor = test::AsTensor<int>({99});
  Tensor classes_tensor;
  const Status status = RunClassification(signature, input_tensor, &session,
                                          &classes_tensor, nullptr);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Expected 1 output")) << message;
}
394 
// The input holds two elements but the mock session returns a one-element
// output; RunClassification is expected to report the batch-size mismatch.
TEST(RunClassification, WrongBatchOutputs) {
  MockSession session;
  session.outputs = {test::AsTensor<int>({3})};

  ClassificationSignature signature;
  signature.mutable_input()->set_tensor_name("in:0");
  signature.mutable_classes()->set_tensor_name("classes:0");

  const Tensor input_tensor = test::AsTensor<int>({99, 100});
  Tensor classes_tensor;
  const Status status = RunClassification(signature, input_tensor, &session,
                                          &classes_tensor, nullptr);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(
      message, "Input batch size did not match output batch size"))
      << message;
}
412 
// Output tensor name bound into the regression signature fixture below.
constexpr char kRegressionsName[]{"regressions:0"};
414 
415 class RunRegressionTest : public ::testing::Test {
416  public:
SetUp()417   void SetUp() override {
418     signature_.mutable_input()->set_tensor_name(kInputName);
419     signature_.mutable_output()->set_tensor_name(kRegressionsName);
420   }
421 
422  protected:
423   RegressionSignature signature_;
424   Tensor input_tensor_;
425   Tensor output_tensor_;
426   MockSession session_;
427 };
428 
// A two-element input yields the session's two-element output, and the
// session must have been fed the input and asked for the regression tensor.
TEST_F(RunRegressionTest, Basic) {
  input_tensor_ = test::AsTensor<int>({99, 100});
  session_.outputs = {test::AsTensor<float>({1, 2})};

  TF_ASSERT_OK(
      RunRegression(signature_, input_tensor_, &session_, &output_tensor_));

  // What RunRegression handed back.
  test::ExpectTensorEqual<float>(test::AsTensor<float>({1, 2}), output_tensor_);

  // What the session was fed.
  const auto& fed = session_.inputs;
  ASSERT_EQ(1, fed.size());
  EXPECT_EQ(kInputName, fed[0].first);
  test::ExpectTensorEqual<int>(test::AsTensor<int>({99, 100}), fed[0].second);

  // What the session was asked to fetch.
  const auto& requested = session_.output_tensor_names;
  ASSERT_EQ(1, requested.size());
  EXPECT_EQ(kRegressionsName, requested[0]);
}
448 
// The mock session's Run() returns a DataLoss error; RunRegression is expected
// to fail and carry that error text.
TEST_F(RunRegressionTest, RunNotOk) {
  session_.status = errors::DataLoss("Data is gone");
  input_tensor_ = test::AsTensor<int>({99});

  const Status status =
      RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Data is gone")) << message;
}
458 
// The input holds two elements but the mock session returns a one-element
// output; RunRegression is expected to report the batch-size mismatch.
TEST_F(RunRegressionTest, MismatchedSizeForBatchInputAndOutput) {
  session_.outputs = {test::AsTensor<float>({3})};
  input_tensor_ = test::AsTensor<int>({99, 100});

  const Status status =
      RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(
      message, "Input batch size did not match output batch size"))
      << message;
}
471 
// Signatures written with SetSignatures must be readable with GetSignatures,
// preserving the default classification signature's input tensor name.
TEST(SetAndGetSignatures, RoundTrip) {
  Signatures written;
  ClassificationSignature* const classification =
      written.mutable_default_signature()->mutable_classification_signature();
  classification->mutable_input()->set_tensor_name("in:0");

  tensorflow::MetaGraphDef meta_graph_def;
  TF_ASSERT_OK(SetSignatures(written, &meta_graph_def));

  Signatures read_back;
  TF_ASSERT_OK(GetSignatures(meta_graph_def, &read_back));

  const auto& read_classification =
      read_back.default_signature().classification_signature();
  EXPECT_EQ("in:0", read_classification.input().tensor_name());
}
488 
// An empty MetaGraphDef has no signatures collection entry, so GetSignatures
// is expected to fail with FAILED_PRECONDITION.
TEST(GetSignatures, MissingSignature) {
  const tensorflow::MetaGraphDef empty_meta_graph_def;

  Signatures read_signatures;
  const auto status = GetSignatures(empty_meta_graph_def, &read_signatures);
  EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
  const auto& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Expected exactly one"))
      << message;
}
498 
// The Any in the signatures collection wraps a TensorBinding instead of a
// Signatures proto; GetSignatures is expected to reject the type mismatch.
TEST(GetSignatures, WrongProtoInAny) {
  tensorflow::MetaGraphDef meta_graph_def;
  auto* const any_values = (*meta_graph_def.mutable_collection_def())
                               [kSignaturesKey]
                                   .mutable_any_list()
                                   ->mutable_value();
  // Deliberately pack a message type that is not Signatures.
  any_values->Add()->PackFrom(TensorBinding());

  Signatures read_signatures;
  const auto status = GetSignatures(meta_graph_def, &read_signatures);
  EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
  const auto& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message,
                                    "Expected Any type_url for: "
                                    "tensorflow.serving.Signatures"))
      << message;
}
514 
// The Any carries the right type but its payload bytes are overwritten with
// junk; GetSignatures is expected to report an unpack failure.
TEST(GetSignatures, JunkInAny) {
  tensorflow::MetaGraphDef meta_graph_def;
  auto* const any_values = (*meta_graph_def.mutable_collection_def())
                               [kSignaturesKey]
                                   .mutable_any_list()
                                   ->mutable_value();
  auto* const corrupted = any_values->Add();
  // Start from a well-formed Any, then clobber its serialized value.
  corrupted->PackFrom(Signatures());
  corrupted->set_value("junk junk");

  Signatures read_signatures;
  const auto status = GetSignatures(meta_graph_def, &read_signatures);
  EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
  const auto& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Failed to unpack")) << message;
}
529 
// A single Signatures proto may carry a default signature and named signatures
// at the same time. GetSignatures must accept it, and this test also checks
// that both signatures are present in what was read back.
TEST(GetSignatures, DefaultAndNamedTogetherOK) {
  tensorflow::MetaGraphDef meta_graph_def;
  auto& collection_def = *(meta_graph_def.mutable_collection_def());
  auto* any =
      collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add();
  Signatures signatures;
  signatures.mutable_default_signature()
      ->mutable_classification_signature()
      ->mutable_input()
      ->set_tensor_name("in:0");
  ClassificationSignature* input_signature =
      (*signatures.mutable_named_signatures())["foo"]
          .mutable_classification_signature();
  input_signature->mutable_input()->set_tensor_name("flow");

  any->PackFrom(signatures);
  Signatures read_signatures;
  // TF_ASSERT_OK prints the failing Status, which a bare
  // EXPECT_TRUE(status.ok()) does not; it also stops the test before the
  // content checks below run against an unpopulated proto.
  TF_ASSERT_OK(GetSignatures(meta_graph_def, &read_signatures));

  // The default signature is present.
  EXPECT_EQ("in:0", read_signatures.default_signature()
                        .classification_signature()
                        .input()
                        .tensor_name());

  // The named signature is present as well.
  const auto& named_signatures = read_signatures.named_signatures();
  ASSERT_EQ(1, named_signatures.count("foo"));
  EXPECT_EQ("flow", named_signatures.at("foo")
                        .classification_signature()
                        .input()
                        .tensor_name());
}
551 
552 // Check that we only have one 'Signatures' entry in the collection_def map.
553 // Note that each such object can have multiple named_signatures inside of it.
TEST(GetSignatures,MultipleSignaturesNotOK)554 TEST(GetSignatures, MultipleSignaturesNotOK) {
555   tensorflow::MetaGraphDef meta_graph_def;
556   auto& collection_def = *(meta_graph_def.mutable_collection_def());
557   auto* any =
558       collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add();
559   Signatures signatures;
560   signatures.mutable_default_signature()
561       ->mutable_classification_signature()
562       ->mutable_input()
563       ->set_tensor_name("in:0");
564   any->PackFrom(signatures);
565 
566   // Add another signatures object.
567   any =
568       collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add();
569   any->PackFrom(signatures);
570   Signatures read_signatures;
571   const auto status = GetSignatures(meta_graph_def, &read_signatures);
572   EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
573   EXPECT_TRUE(
574       str_util::StrContains(status.error_message(), "Expected exactly one"))
575       << status.error_message();
576 }
577 
578 // GenericSignature test fixture that contains a signature initialized with two
579 // bound Tensors.
580 class GenericSignatureTest : public ::testing::Test {
581  protected:
GenericSignatureTest()582   GenericSignatureTest() {
583     TensorBinding binding;
584     binding.set_tensor_name("graph_A");
585     signature_.mutable_map()->insert({"logical_A", binding});
586 
587     binding.set_tensor_name("graph_B");
588     signature_.mutable_map()->insert({"logical_B", binding});
589   }
590 
591   // GenericSignature that contains two bound Tensors.
592   GenericSignature signature_;
593 };
594 
595 // GenericSignature tests.
596 
// The fixture's generic signature, stored under the name "generic_bindings",
// is fetched by name with both of its bindings intact.
TEST_F(GenericSignatureTest, GetGenericSignatureBasic) {
  Signature wrapped;
  wrapped.mutable_generic_signature()->MergeFrom(signature_);

  Signatures stored;
  stored.mutable_named_signatures()->insert({"generic_bindings", wrapped});

  tensorflow::MetaGraphDef meta_graph_def;
  auto& collections = *meta_graph_def.mutable_collection_def();
  collections[kSignaturesKey].mutable_any_list()->add_value()->PackFrom(
      stored);

  GenericSignature fetched;
  TF_ASSERT_OK(
      GetGenericSignature("generic_bindings", meta_graph_def, &fetched));
  const auto& bindings = fetched.map();
  ASSERT_EQ("graph_A", bindings.at("logical_A").tensor_name());
  ASSERT_EQ("graph_B", bindings.at("logical_B").tensor_name());
}
616 
// The packed Signatures proto has no named signatures, so looking up
// "generic_bindings" is expected to fail with a "missing" message.
TEST(GetGenericSignature, MissingSignature) {
  const Signatures empty_signatures;

  tensorflow::MetaGraphDef meta_graph_def;
  auto* packed = (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
                     .mutable_any_list()
                     ->add_value();
  packed->PackFrom(empty_signatures);

  GenericSignature fetched;
  const Status status =
      GetGenericSignature("generic_bindings", meta_graph_def, &fetched);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(HasSubstr(
      message, "Missing generic signature named \"generic_bindings\""))
      << message;
}
633 
// "generic_bindings" is registered as a regression signature, so requesting it
// as a generic signature is expected to fail.
TEST(GetGenericSignature, WrongSignatureType) {
  Signatures stored;
  auto& named = *stored.mutable_named_signatures();
  named["generic_bindings"].mutable_regression_signature();

  tensorflow::MetaGraphDef meta_graph_def;
  auto* packed = (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
                     .mutable_any_list()
                     ->add_value();
  packed->PackFrom(stored);

  GenericSignature fetched;
  const Status status =
      GetGenericSignature("generic_bindings", meta_graph_def, &fetched);
  ASSERT_FALSE(status.ok());
  const auto& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Expected a generic signature:"))
      << message;
}
652 
653 // BindGeneric Tests.
654 
// Each logical input name is translated to its bound graph tensor name, with
// the tensors carried over unchanged and in the same order.
TEST_F(GenericSignatureTest, BindGenericInputsBasic) {
  const std::vector<std::pair<string, Tensor>> inputs = {
      {"logical_A", test::AsTensor<float>({-1.0})},
      {"logical_B", test::AsTensor<float>({-2.0})}};

  std::vector<std::pair<string, Tensor>> bound_inputs;
  TF_ASSERT_OK(BindGenericInputs(signature_, inputs, &bound_inputs));

  // Stop here if the result is the wrong size: indexing below would otherwise
  // read past the end of the vector instead of failing cleanly.
  ASSERT_EQ(2, bound_inputs.size());

  EXPECT_EQ("graph_A", bound_inputs[0].first);
  EXPECT_EQ("graph_B", bound_inputs[1].first);
  test::ExpectTensorEqual<float>(test::AsTensor<float>({-1.0}),
                                 bound_inputs[0].second);
  test::ExpectTensorEqual<float>(test::AsTensor<float>({-2.0}),
                                 bound_inputs[1].second);
}
670 
// "logical_MISSING" has no binding in the fixture's signature, so binding the
// inputs is expected to fail.
TEST_F(GenericSignatureTest, BindGenericInputsMissingBinding) {
  const std::vector<std::pair<string, Tensor>> logical_inputs = {
      {"logical_A", test::AsTensor<float>({-42.0})},
      {"logical_MISSING", test::AsTensor<float>({-43.0})}};

  std::vector<std::pair<string, Tensor>> bound;
  const Status status = BindGenericInputs(signature_, logical_inputs, &bound);
  ASSERT_FALSE(status.ok());
}
680 
// Each logical name is translated to its bound graph tensor name, preserving
// the order in which the names were requested.
TEST_F(GenericSignatureTest, BindGenericNamesBasic) {
  const std::vector<string> input_names = {"logical_B", "logical_A"};
  std::vector<string> bound_names;
  TF_ASSERT_OK(BindGenericNames(signature_, input_names, &bound_names));

  // Stop here if the result is the wrong size: indexing below would otherwise
  // read past the end of the vector instead of failing cleanly.
  ASSERT_EQ(2, bound_names.size());

  EXPECT_EQ("graph_B", bound_names[0]);
  EXPECT_EQ("graph_A", bound_names[1]);
}
689 
// "logical_MISSING" has no binding in the fixture's signature, so binding the
// names is expected to fail.
TEST_F(GenericSignatureTest, BindGenericNamesMissingBinding) {
  const std::vector<string> logical_names = {"logical_B", "logical_MISSING"};

  std::vector<string> bound;
  const Status status = BindGenericNames(signature_, logical_names, &bound);
  ASSERT_FALSE(status.ok());
}
696 
697 }  // namespace
698 }  // namespace serving
699 }  // namespace tensorflow
700