1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/contrib/session_bundle/signature.h"
17
18 #include <memory>
19
20 #include "google/protobuf/any.pb.h"
21 #include "tensorflow/contrib/session_bundle/manifest.pb.h"
22 #include "tensorflow/core/framework/graph.pb.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_testutil.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/lib/core/stringpiece.h"
29 #include "tensorflow/core/lib/strings/str_util.h"
30 #include "tensorflow/core/platform/test.h"
31 #include "tensorflow/core/public/session.h"
32
33 namespace tensorflow {
34 namespace serving {
35 namespace {
36
HasSubstr(StringPiece base,StringPiece substr)37 static bool HasSubstr(StringPiece base, StringPiece substr) {
38 bool ok = str_util::StrContains(base, substr);
39 EXPECT_TRUE(ok) << base << ", expected substring " << substr;
40 return ok;
41 }
42
// A classification signature stored as the default signature is returned by
// GetClassificationSignature with its input binding intact.
TEST(GetClassificationSignature, Basic) {
  Signatures signatures_proto;
  signatures_proto.mutable_default_signature()
      ->mutable_classification_signature()
      ->mutable_input()
      ->set_tensor_name("flow");

  tensorflow::MetaGraphDef meta_graph;
  auto& collection = (*meta_graph.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(signatures_proto);

  ClassificationSignature signature;
  const Status status = GetClassificationSignature(meta_graph, &signature);
  TF_ASSERT_OK(status);
  EXPECT_EQ(signature.input().tensor_name(), "flow");
}
60
// A default signature with no signature type populated is rejected by
// GetClassificationSignature.
TEST(GetClassificationSignature, MissingSignature) {
  Signatures signatures_proto;
  // Creates the default signature but leaves its type unset.
  signatures_proto.mutable_default_signature();

  tensorflow::MetaGraphDef meta_graph;
  auto& collection = (*meta_graph.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(signatures_proto);

  ClassificationSignature signature;
  const Status status = GetClassificationSignature(meta_graph, &signature);
  ASSERT_FALSE(status.ok());
  const string& message = status.error_message();
  EXPECT_TRUE(
      str_util::StrContains(message, "Expected a classification signature"))
      << message;
}
77
// A default signature of regression type is rejected by
// GetClassificationSignature.
TEST(GetClassificationSignature, WrongSignatureType) {
  Signatures signatures_proto;
  signatures_proto.mutable_default_signature()->mutable_regression_signature();

  tensorflow::MetaGraphDef meta_graph;
  auto& collection = (*meta_graph.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(signatures_proto);

  ClassificationSignature signature;
  const Status status = GetClassificationSignature(meta_graph, &signature);
  ASSERT_FALSE(status.ok());
  const string& message = status.error_message();
  EXPECT_TRUE(
      str_util::StrContains(message, "Expected a classification signature"))
      << message;
}
94
// A classification signature registered under a name is found by
// GetNamedClassificationSignature with its input binding intact.
TEST(GetNamedClassificationSignature, Basic) {
  Signatures signatures_proto;
  Signature& foo = (*signatures_proto.mutable_named_signatures())["foo"];
  foo.mutable_classification_signature()->mutable_input()->set_tensor_name(
      "flow");

  tensorflow::MetaGraphDef meta_graph;
  auto& collection = (*meta_graph.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(signatures_proto);

  ClassificationSignature signature;
  const Status status =
      GetNamedClassificationSignature("foo", meta_graph, &signature);
  TF_ASSERT_OK(status);
  EXPECT_EQ(signature.input().tensor_name(), "flow");
}
113
// Looking up a name that was never registered reports it as missing.
TEST(GetNamedClassificationSignature, MissingSignature) {
  // An empty Signatures message: no default and no named signatures.
  const Signatures signatures_proto;

  tensorflow::MetaGraphDef meta_graph;
  auto& collection = (*meta_graph.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(signatures_proto);

  ClassificationSignature signature;
  const Status status =
      GetNamedClassificationSignature("foo", meta_graph, &signature);
  ASSERT_FALSE(status.ok());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Missing signature named \"foo\""))
      << message;
}
130
// A name bound to a regression signature is rejected when a classification
// signature is requested, and the error mentions the requested name.
TEST(GetNamedClassificationSignature, WrongSignatureType) {
  Signatures signatures_proto;
  (*signatures_proto.mutable_named_signatures())["foo"]
      .mutable_regression_signature();

  tensorflow::MetaGraphDef meta_graph;
  auto& collection = (*meta_graph.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(signatures_proto);

  ClassificationSignature signature;
  const Status status =
      GetNamedClassificationSignature("foo", meta_graph, &signature);
  ASSERT_FALSE(status.ok());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(
      message, "Expected a classification signature for name \"foo\""))
      << message;
}
150
// A regression signature stored as the default signature is returned by
// GetRegressionSignature with its input binding intact.
TEST(GetRegressionSignature, Basic) {
  Signatures signatures_proto;
  signatures_proto.mutable_default_signature()
      ->mutable_regression_signature()
      ->mutable_input()
      ->set_tensor_name("flow");

  tensorflow::MetaGraphDef meta_graph;
  auto& collection = (*meta_graph.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(signatures_proto);

  RegressionSignature signature;
  const Status status = GetRegressionSignature(meta_graph, &signature);
  TF_ASSERT_OK(status);
  EXPECT_EQ(signature.input().tensor_name(), "flow");
}
167
// A default signature with no signature type populated is rejected by
// GetRegressionSignature.
TEST(GetRegressionSignature, MissingSignature) {
  Signatures signatures_proto;
  // Creates the default signature but leaves its type unset.
  signatures_proto.mutable_default_signature();

  tensorflow::MetaGraphDef meta_graph;
  auto& collection = (*meta_graph.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(signatures_proto);

  RegressionSignature signature;
  const Status status = GetRegressionSignature(meta_graph, &signature);
  ASSERT_FALSE(status.ok());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Expected a regression signature"))
      << message;
}
184
// A default signature of classification type is rejected by
// GetRegressionSignature.
TEST(GetRegressionSignature, WrongSignatureType) {
  Signatures signatures_proto;
  signatures_proto.mutable_default_signature()
      ->mutable_classification_signature();

  tensorflow::MetaGraphDef meta_graph;
  auto& collection = (*meta_graph.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(signatures_proto);

  RegressionSignature signature;
  const Status status = GetRegressionSignature(meta_graph, &signature);
  ASSERT_FALSE(status.ok());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Expected a regression signature"))
      << message;
}
201
// GetNamedSignature returns the Signature registered under the given name;
// here that is a classification signature with input "flow".
TEST(GetNamedSignature, Basic) {
  Signatures signatures_proto;
  Signature& named = (*signatures_proto.mutable_named_signatures())["foo"];
  named.mutable_classification_signature()->mutable_input()->set_tensor_name(
      "flow");

  tensorflow::MetaGraphDef meta_graph;
  auto& collection = (*meta_graph.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(signatures_proto);

  Signature signature;
  const Status status = GetNamedSignature("foo", meta_graph, &signature);
  TF_ASSERT_OK(status);
  EXPECT_EQ(signature.classification_signature().input().tensor_name(), "flow");
}
219
// GetNamedSignature reports a name that was never registered as missing.
TEST(GetNamedSignature, MissingSignature) {
  // An empty Signatures message: no default and no named signatures.
  const Signatures signatures_proto;

  tensorflow::MetaGraphDef meta_graph;
  auto& collection = (*meta_graph.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(signatures_proto);

  Signature signature;
  const Status status = GetNamedSignature("foo", meta_graph, &signature);
  ASSERT_FALSE(status.ok());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Missing signature named \"foo\""))
      << message;
}
235
236 // MockSession used to test input and output interactions with a
237 // tensorflow::Session.
238 struct MockSession : public tensorflow::Session {
239 ~MockSession() override = default;
240
Createtensorflow::serving::__anond1cc78060111::MockSession241 Status Create(const GraphDef& graph) override {
242 return errors::Unimplemented("Not implemented for mock.");
243 }
244
Extendtensorflow::serving::__anond1cc78060111::MockSession245 Status Extend(const GraphDef& graph) override {
246 return errors::Unimplemented("Not implemented for mock.");
247 }
248
249 // Sets the input and output arguments.
Runtensorflow::serving::__anond1cc78060111::MockSession250 Status Run(const std::vector<std::pair<string, Tensor>>& inputs_arg,
251 const std::vector<string>& output_tensor_names_arg,
252 const std::vector<string>& target_node_names_arg,
253 std::vector<Tensor>* outputs_arg) override {
254 inputs = inputs_arg;
255 output_tensor_names = output_tensor_names_arg;
256 target_node_names = target_node_names_arg;
257 *outputs_arg = outputs;
258 return status;
259 }
260
Closetensorflow::serving::__anond1cc78060111::MockSession261 Status Close() override {
262 return errors::Unimplemented("Not implemented for mock.");
263 }
264
ListDevicestensorflow::serving::__anond1cc78060111::MockSession265 Status ListDevices(std::vector<DeviceAttributes>* response) override {
266 return errors::Unimplemented("Not implemented for mock.");
267 }
268
269 // Arguments stored on a Run call.
270 std::vector<std::pair<string, Tensor>> inputs;
271 std::vector<string> output_tensor_names;
272 std::vector<string> target_node_names;
273
274 // Output argument set by Run; should be set before calling.
275 std::vector<Tensor> outputs;
276
277 // Return value for Run; should be set before calling.
278 Status status;
279 };
280
// Graph tensor names bound into the signatures under test. The tests then
// expect these exact names in MockSession's recorded Run() arguments.
constexpr char kInputName[]{"in:0"};
constexpr char kClassesName[]{"classes:0"};
constexpr char kScoresName[]{"scores:0"};
284
285 class RunClassificationTest : public ::testing::Test {
286 public:
SetUp()287 void SetUp() override {
288 signature_.mutable_input()->set_tensor_name(kInputName);
289 signature_.mutable_classes()->set_tensor_name(kClassesName);
290 signature_.mutable_scores()->set_tensor_name(kScoresName);
291 }
292
293 protected:
294 ClassificationSignature signature_;
295 Tensor input_tensor_;
296 Tensor classes_tensor_;
297 Tensor scores_tensor_;
298 MockSession session_;
299 };
300
// Requesting both classes and scores feeds one input and fetches both
// outputs, classes first.
TEST_F(RunClassificationTest, Basic) {
  session_.outputs = {test::AsTensor<int>({3}), test::AsTensor<int>({2})};
  input_tensor_ = test::AsTensor<int>({99});
  const Status status = RunClassification(signature_, input_tensor_, &session_,
                                          &classes_tensor_, &scores_tensor_);

  // Both outputs reach the caller.
  TF_ASSERT_OK(status);
  test::ExpectTensorEqual<int>(test::AsTensor<int>({3}), classes_tensor_);
  test::ExpectTensorEqual<int>(test::AsTensor<int>({2}), scores_tensor_);

  // Exactly one tensor was fed, under the signature's input name.
  const auto& fed = session_.inputs;
  ASSERT_EQ(1, fed.size());
  EXPECT_EQ(kInputName, fed[0].first);
  test::ExpectTensorEqual<int>(test::AsTensor<int>({99}), fed[0].second);

  // Classes and scores were fetched, in that order.
  const auto& fetched = session_.output_tensor_names;
  ASSERT_EQ(2, fetched.size());
  EXPECT_EQ(kClassesName, fetched[0]);
  EXPECT_EQ(kScoresName, fetched[1]);
}
322
// With no scores tensor requested (nullptr), only the classes output is
// fetched.
TEST_F(RunClassificationTest, ClassesOnly) {
  session_.outputs = {test::AsTensor<int>({3})};
  input_tensor_ = test::AsTensor<int>({99});
  const Status status = RunClassification(signature_, input_tensor_, &session_,
                                          &classes_tensor_, nullptr);

  TF_ASSERT_OK(status);
  test::ExpectTensorEqual<int>(test::AsTensor<int>({3}), classes_tensor_);

  const auto& fed = session_.inputs;
  ASSERT_EQ(1, fed.size());
  EXPECT_EQ(kInputName, fed[0].first);
  test::ExpectTensorEqual<int>(test::AsTensor<int>({99}), fed[0].second);

  const auto& fetched = session_.output_tensor_names;
  ASSERT_EQ(1, fetched.size());
  EXPECT_EQ(kClassesName, fetched[0]);
}
342
// With no classes tensor requested (nullptr), only the scores output is
// fetched.
TEST_F(RunClassificationTest, ScoresOnly) {
  session_.outputs = {test::AsTensor<int>({2})};
  input_tensor_ = test::AsTensor<int>({99});
  const Status status = RunClassification(signature_, input_tensor_, &session_,
                                          nullptr, &scores_tensor_);

  TF_ASSERT_OK(status);
  test::ExpectTensorEqual<int>(test::AsTensor<int>({2}), scores_tensor_);

  const auto& fed = session_.inputs;
  ASSERT_EQ(1, fed.size());
  EXPECT_EQ(kInputName, fed[0].first);
  test::ExpectTensorEqual<int>(test::AsTensor<int>({99}), fed[0].second);

  const auto& fetched = session_.output_tensor_names;
  ASSERT_EQ(1, fetched.size());
  EXPECT_EQ(kScoresName, fetched[0]);
}
362
// An error returned by the session's Run() is propagated by RunClassification.
TEST(RunClassification, RunNotOk) {
  ClassificationSignature classification;
  classification.mutable_input()->set_tensor_name("in:0");
  classification.mutable_classes()->set_tensor_name("classes:0");

  MockSession session;
  session.status = errors::DataLoss("Data is gone");

  Tensor input = test::AsTensor<int>({99});
  Tensor classes;
  const Status status =
      RunClassification(classification, input, &session, &classes, nullptr);
  ASSERT_FALSE(status.ok());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Data is gone")) << message;
}
377
// Only the classes output is requested, so a session answering with two
// tensors is an error.
TEST(RunClassification, TooManyOutputs) {
  ClassificationSignature classification;
  classification.mutable_input()->set_tensor_name("in:0");
  classification.mutable_classes()->set_tensor_name("classes:0");

  MockSession session;
  session.outputs = {test::AsTensor<int>({3}), test::AsTensor<int>({4})};

  Tensor input = test::AsTensor<int>({99});
  Tensor classes;
  const Status status =
      RunClassification(classification, input, &session, &classes, nullptr);
  ASSERT_FALSE(status.ok());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Expected 1 output")) << message;
}
394
// Two inputs answered with a single classes entry are rejected as a
// batch-size mismatch.
TEST(RunClassification, WrongBatchOutputs) {
  ClassificationSignature classification;
  classification.mutable_input()->set_tensor_name("in:0");
  classification.mutable_classes()->set_tensor_name("classes:0");

  MockSession session;
  session.outputs = {test::AsTensor<int>({3})};

  Tensor input = test::AsTensor<int>({99, 100});
  Tensor classes;
  const Status status =
      RunClassification(classification, input, &session, &classes, nullptr);
  ASSERT_FALSE(status.ok());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(
      message, "Input batch size did not match output batch size"))
      << message;
}
412
// Graph tensor name bound as the regression signature's output.
constexpr char kRegressionsName[]{"regressions:0"};
414
415 class RunRegressionTest : public ::testing::Test {
416 public:
SetUp()417 void SetUp() override {
418 signature_.mutable_input()->set_tensor_name(kInputName);
419 signature_.mutable_output()->set_tensor_name(kRegressionsName);
420 }
421
422 protected:
423 RegressionSignature signature_;
424 Tensor input_tensor_;
425 Tensor output_tensor_;
426 MockSession session_;
427 };
428
// A batch of two inputs yields the session's two float regressions, and the
// session sees exactly one feed and one fetch.
TEST_F(RunRegressionTest, Basic) {
  session_.outputs = {test::AsTensor<float>({1, 2})};
  input_tensor_ = test::AsTensor<int>({99, 100});
  const Status status =
      RunRegression(signature_, input_tensor_, &session_, &output_tensor_);

  TF_ASSERT_OK(status);
  test::ExpectTensorEqual<float>(test::AsTensor<float>({1, 2}), output_tensor_);

  const auto& fed = session_.inputs;
  ASSERT_EQ(1, fed.size());
  EXPECT_EQ(kInputName, fed[0].first);
  test::ExpectTensorEqual<int>(test::AsTensor<int>({99, 100}), fed[0].second);

  const auto& fetched = session_.output_tensor_names;
  ASSERT_EQ(1, fetched.size());
  EXPECT_EQ(kRegressionsName, fetched[0]);
}
448
// An error returned by the session's Run() is propagated by RunRegression.
TEST_F(RunRegressionTest, RunNotOk) {
  session_.status = errors::DataLoss("Data is gone");
  input_tensor_ = test::AsTensor<int>({99});
  const Status status =
      RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
  ASSERT_FALSE(status.ok());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Data is gone")) << message;
}
458
// Two inputs answered with a single regression value are rejected as a
// batch-size mismatch.
TEST_F(RunRegressionTest, MismatchedSizeForBatchInputAndOutput) {
  session_.outputs = {test::AsTensor<float>({3})};
  input_tensor_ = test::AsTensor<int>({99, 100});

  const Status status =
      RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
  ASSERT_FALSE(status.ok());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(
      message, "Input batch size did not match output batch size"))
      << message;
}
471
// Signatures written with SetSignatures are read back unchanged by
// GetSignatures.
TEST(SetAndGetSignatures, RoundTrip) {
  Signatures written;
  written.mutable_default_signature()
      ->mutable_classification_signature()
      ->mutable_input()
      ->set_tensor_name("in:0");

  tensorflow::MetaGraphDef meta_graph_def;
  TF_ASSERT_OK(SetSignatures(written, &meta_graph_def));

  Signatures read_signatures;
  TF_ASSERT_OK(GetSignatures(meta_graph_def, &read_signatures));

  const ClassificationSignature& read_back =
      read_signatures.default_signature().classification_signature();
  EXPECT_EQ("in:0", read_back.input().tensor_name());
}
488
// A MetaGraphDef without any signatures collection fails the precondition
// check.
TEST(GetSignatures, MissingSignature) {
  tensorflow::MetaGraphDef empty_meta_graph;
  Signatures read_signatures;
  const Status status = GetSignatures(empty_meta_graph, &read_signatures);
  EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Expected exactly one"))
      << message;
}
498
// An Any holding some other message type (here a TensorBinding) is rejected
// because its type_url is not that of tensorflow.serving.Signatures.
TEST(GetSignatures, WrongProtoInAny) {
  tensorflow::MetaGraphDef meta_graph_def;
  (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
      .mutable_any_list()
      ->add_value()
      ->PackFrom(TensorBinding());

  Signatures read_signatures;
  const Status status = GetSignatures(meta_graph_def, &read_signatures);
  EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message,
                                    "Expected Any type_url for: "
                                    "tensorflow.serving.Signatures"))
      << message;
}
514
// An Any with the right type_url but a corrupted payload fails to unpack.
TEST(GetSignatures, JunkInAny) {
  tensorflow::MetaGraphDef meta_graph_def;
  auto* any = (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
                  .mutable_any_list()
                  ->add_value();
  any->PackFrom(Signatures());  // Sets the Signatures type_url...
  any->set_value("junk junk");  // ...then overwrites the serialized bytes.

  Signatures read_signatures;
  const Status status = GetSignatures(meta_graph_def, &read_signatures);
  EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Failed to unpack")) << message;
}
529
// A single Signatures message may carry both a default signature and named
// signatures. GetSignatures must accept it and return both intact.
TEST(GetSignatures, DefaultAndNamedTogetherOK) {
  tensorflow::MetaGraphDef meta_graph_def;
  auto& collection_def = *(meta_graph_def.mutable_collection_def());
  auto* any =
      collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add();
  Signatures signatures;
  signatures.mutable_default_signature()
      ->mutable_classification_signature()
      ->mutable_input()
      ->set_tensor_name("in:0");
  ClassificationSignature* input_signature =
      (*signatures.mutable_named_signatures())["foo"]
          .mutable_classification_signature();
  input_signature->mutable_input()->set_tensor_name("flow");

  any->PackFrom(signatures);
  Signatures read_signatures;
  const auto status = GetSignatures(meta_graph_def, &read_signatures);

  // TF_ASSERT_OK prints the status message on failure. It also stops the test
  // before the content checks below, which are meaningless without a read.
  TF_ASSERT_OK(status);

  // The default signature survived the round trip.
  EXPECT_EQ("in:0", read_signatures.default_signature()
                        .classification_signature()
                        .input()
                        .tensor_name());

  // So did the named one; check presence first because Map::at() on a
  // missing key aborts instead of failing the test.
  ASSERT_EQ(1, read_signatures.named_signatures().count("foo"));
  EXPECT_EQ("flow", read_signatures.named_signatures()
                        .at("foo")
                        .classification_signature()
                        .input()
                        .tensor_name());
}
551
552 // Check that we only have one 'Signatures' entry in the collection_def map.
553 // Note that each such object can have multiple named_signatures inside of it.
TEST(GetSignatures,MultipleSignaturesNotOK)554 TEST(GetSignatures, MultipleSignaturesNotOK) {
555 tensorflow::MetaGraphDef meta_graph_def;
556 auto& collection_def = *(meta_graph_def.mutable_collection_def());
557 auto* any =
558 collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add();
559 Signatures signatures;
560 signatures.mutable_default_signature()
561 ->mutable_classification_signature()
562 ->mutable_input()
563 ->set_tensor_name("in:0");
564 any->PackFrom(signatures);
565
566 // Add another signatures object.
567 any =
568 collection_def[kSignaturesKey].mutable_any_list()->mutable_value()->Add();
569 any->PackFrom(signatures);
570 Signatures read_signatures;
571 const auto status = GetSignatures(meta_graph_def, &read_signatures);
572 EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
573 EXPECT_TRUE(
574 str_util::StrContains(status.error_message(), "Expected exactly one"))
575 << status.error_message();
576 }
577
578 // GenericSignature test fixture that contains a signature initialized with two
579 // bound Tensors.
580 class GenericSignatureTest : public ::testing::Test {
581 protected:
GenericSignatureTest()582 GenericSignatureTest() {
583 TensorBinding binding;
584 binding.set_tensor_name("graph_A");
585 signature_.mutable_map()->insert({"logical_A", binding});
586
587 binding.set_tensor_name("graph_B");
588 signature_.mutable_map()->insert({"logical_B", binding});
589 }
590
591 // GenericSignature that contains two bound Tensors.
592 GenericSignature signature_;
593 };
594
595 // GenericSignature tests.
596
// The fixture's generic signature, stored under "generic_bindings", is read
// back by GetGenericSignature with both bindings intact.
TEST_F(GenericSignatureTest, GetGenericSignatureBasic) {
  Signatures signatures;
  (*signatures.mutable_named_signatures())["generic_bindings"]
      .mutable_generic_signature()
      ->MergeFrom(signature_);

  tensorflow::MetaGraphDef meta_graph_def;
  auto& collection =
      (*meta_graph_def.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(signatures);

  GenericSignature actual_signature;
  TF_ASSERT_OK(GetGenericSignature("generic_bindings", meta_graph_def,
                                   &actual_signature));
  ASSERT_EQ("graph_A", actual_signature.map().at("logical_A").tensor_name());
  ASSERT_EQ("graph_B", actual_signature.map().at("logical_B").tensor_name());
}
616
// With nothing registered under "generic_bindings", GetGenericSignature
// reports the missing name.
TEST(GetGenericSignature, MissingSignature) {
  const Signatures no_signatures;

  tensorflow::MetaGraphDef meta_graph;
  auto& collection = (*meta_graph.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(no_signatures);

  GenericSignature signature;
  const Status status =
      GetGenericSignature("generic_bindings", meta_graph, &signature);
  ASSERT_FALSE(status.ok());
  // HasSubstr also records its own failure on a miss; the streamed message
  // adds the full error text.
  EXPECT_TRUE(HasSubstr(status.error_message(),
                        "Missing generic signature named \"generic_bindings\""))
      << status.error_message();
}
633
// A name bound to a regression signature is rejected when a generic
// signature is requested.
TEST(GetGenericSignature, WrongSignatureType) {
  Signatures signatures_proto;
  (*signatures_proto.mutable_named_signatures())["generic_bindings"]
      .mutable_regression_signature();

  tensorflow::MetaGraphDef meta_graph;
  auto& collection = (*meta_graph.mutable_collection_def())[kSignaturesKey];
  collection.mutable_any_list()->add_value()->PackFrom(signatures_proto);

  GenericSignature signature;
  const Status status =
      GetGenericSignature("generic_bindings", meta_graph, &signature);
  ASSERT_FALSE(status.ok());
  const string& message = status.error_message();
  EXPECT_TRUE(str_util::StrContains(message, "Expected a generic signature:"))
      << message;
}
652
653 // BindGeneric Tests.
654
// BindGenericInputs renames each logical input to its bound graph tensor
// name and leaves the tensor values untouched.
TEST_F(GenericSignatureTest, BindGenericInputsBasic) {
  const std::vector<std::pair<string, Tensor>> inputs = {
      {"logical_A", test::AsTensor<float>({-1.0})},
      {"logical_B", test::AsTensor<float>({-2.0})}};

  std::vector<std::pair<string, Tensor>> bound_inputs;
  TF_ASSERT_OK(BindGenericInputs(signature_, inputs, &bound_inputs));

  // Check the size before indexing. A short result then fails cleanly instead
  // of being read out of bounds (undefined behavior).
  ASSERT_EQ(2, bound_inputs.size());
  EXPECT_EQ("graph_A", bound_inputs[0].first);
  EXPECT_EQ("graph_B", bound_inputs[1].first);
  test::ExpectTensorEqual<float>(test::AsTensor<float>({-1.0}),
                                 bound_inputs[0].second);
  test::ExpectTensorEqual<float>(test::AsTensor<float>({-2.0}),
                                 bound_inputs[1].second);
}
670
// An input whose logical name has no binding in the signature makes
// BindGenericInputs fail.
TEST_F(GenericSignatureTest, BindGenericInputsMissingBinding) {
  std::vector<std::pair<string, Tensor>> bound_inputs;
  const std::vector<std::pair<string, Tensor>> inputs = {
      {"logical_A", test::AsTensor<float>({-42.0})},
      {"logical_MISSING", test::AsTensor<float>({-43.0})}};

  const Status status = BindGenericInputs(signature_, inputs, &bound_inputs);
  ASSERT_FALSE(status.ok());
}
680
// BindGenericNames maps each logical name to its graph tensor name, in the
// order requested (here B before A).
TEST_F(GenericSignatureTest, BindGenericNamesBasic) {
  const std::vector<string> input_names = {"logical_B", "logical_A"};
  std::vector<string> bound_names;
  TF_ASSERT_OK(BindGenericNames(signature_, input_names, &bound_names));

  // Check the size before indexing. A short result then fails cleanly instead
  // of being read out of bounds (undefined behavior).
  ASSERT_EQ(2, bound_names.size());
  EXPECT_EQ("graph_B", bound_names[0]);
  EXPECT_EQ("graph_A", bound_names[1]);
}
689
// A logical name with no binding in the signature makes BindGenericNames fail.
TEST_F(GenericSignatureTest, BindGenericNamesMissingBinding) {
  std::vector<string> bound_names;
  const std::vector<string> input_names = {"logical_B", "logical_MISSING"};

  const Status status = BindGenericNames(signature_, input_names, &bound_names);
  ASSERT_FALSE(status.ok());
}
696
697 } // namespace
698 } // namespace serving
699 } // namespace tensorflow
700