1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/contrib/session_bundle/signature.h"
17 
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "google/protobuf/any.pb.h"
23 #include "tensorflow/contrib/session_bundle/manifest.pb.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/lib/strings/strcat.h"
28 #include "tensorflow/core/platform/protobuf_internal.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/protobuf/meta_graph.pb.h"
31 #include "tensorflow/core/public/session.h"
32 
33 namespace tensorflow {
34 namespace serving {
35 namespace {
36 
37 // Returns OK if the input and output batch sizes match.
BatchSizesMatch(const Tensor & input,const Tensor & output)38 Status BatchSizesMatch(const Tensor& input, const Tensor& output) {
39   // Ensure the number of outputs match the number of inputs.
40   if (input.dim_size(0) != output.dim_size(0)) {
41     return errors::Internal(strings::StrCat(
42         "Input batch size did not match output batch size: ", input.dim_size(0),
43         " vs. ", output.dim_size(0)));
44   }
45   return Status::OK();
46 }
47 }  // namespace
48 
GetSignatures(const tensorflow::MetaGraphDef & meta_graph_def,Signatures * signatures)49 Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def,
50                      Signatures* signatures) {
51   const auto& collection_def = meta_graph_def.collection_def();
52   const auto it = collection_def.find(kSignaturesKey);
53   if (it == collection_def.end() || it->second.any_list().value_size() != 1) {
54     return errors::FailedPrecondition(
55         strings::StrCat("Expected exactly one signatures proto in : ",
56                         DebugStringIfAvailable(meta_graph_def)));
57   }
58   const auto& any = it->second.any_list().value(0);
59   return ParseAny(any, signatures, "tensorflow.serving.Signatures");
60 }
61 
SetSignatures(const Signatures & signatures,tensorflow::MetaGraphDef * meta_graph_def)62 Status SetSignatures(const Signatures& signatures,
63                      tensorflow::MetaGraphDef* meta_graph_def) {
64   auto& collection_def = *(meta_graph_def->mutable_collection_def());
65   auto* any_list = collection_def[kSignaturesKey].mutable_any_list();
66   any_list->mutable_value()->Clear();
67 #ifdef TENSORFLOW_LITE_PROTOS
68   signatures.SerializeToString(
69       any_list->mutable_value()->Add()->mutable_value());
70 #else
71   any_list->mutable_value()->Add()->PackFrom(signatures);
72 #endif
73   return Status::OK();
74 }
75 
GetClassificationSignature(const tensorflow::MetaGraphDef & meta_graph_def,ClassificationSignature * signature)76 Status GetClassificationSignature(
77     const tensorflow::MetaGraphDef& meta_graph_def,
78     ClassificationSignature* signature) {
79   Signatures signatures;
80   TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
81   if (!signatures.has_default_signature()) {
82     return errors::FailedPrecondition(
83         strings::StrCat("Expected a default signature in: ",
84                         DebugStringIfAvailable(signatures)));
85   }
86   if (!signatures.default_signature().has_classification_signature()) {
87     return errors::FailedPrecondition(strings::StrCat(
88         "Expected a classification signature in: ",
89         DebugStringIfAvailable(signatures.default_signature())));
90   }
91   *signature = signatures.default_signature().classification_signature();
92   return Status::OK();
93 }
94 
GetNamedClassificationSignature(const string & name,const tensorflow::MetaGraphDef & meta_graph_def,ClassificationSignature * signature)95 Status GetNamedClassificationSignature(
96     const string& name, const tensorflow::MetaGraphDef& meta_graph_def,
97     ClassificationSignature* signature) {
98   Signatures signatures;
99   TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
100   const auto& it = signatures.named_signatures().find(name);
101   if (it == signatures.named_signatures().end()) {
102     return errors::NotFound(
103         strings::StrCat("Missing signature named \"", name,
104                         "\" in: ", DebugStringIfAvailable(signatures)));
105   }
106   if (!it->second.has_classification_signature()) {
107     return errors::FailedPrecondition(
108         strings::StrCat("Expected a classification signature for name \"", name,
109                         "\" in: ", DebugStringIfAvailable(it->second)));
110   }
111   *signature = it->second.classification_signature();
112   return Status::OK();
113 }
114 
RunClassification(const ClassificationSignature & signature,const Tensor & input,Session * session,Tensor * classes,Tensor * scores)115 Status RunClassification(const ClassificationSignature& signature,
116                          const Tensor& input, Session* session, Tensor* classes,
117                          Tensor* scores) {
118   std::vector<string> output_tensor_names;
119   if (classes) {
120     output_tensor_names.push_back(signature.classes().tensor_name());
121   }
122   if (scores) {
123     output_tensor_names.push_back(signature.scores().tensor_name());
124   }
125   // Run the graph with our inputs and outputs.
126   std::vector<Tensor> outputs;
127   const Status run_status =
128       session->Run({{signature.input().tensor_name(), input}},
129                    output_tensor_names, {}, &outputs);
130   if (!run_status.ok()) {
131     return run_status;
132   }
133   // Ensure the output is shaped how we expect.
134   // There should be one string Tensor of shape,
135   //   [batch_size, num_recommendations].
136   if (outputs.size() != output_tensor_names.size()) {
137     return errors::Internal(
138         strings::StrCat("Expected ", output_tensor_names.size(),
139                         " output tensor(s).  Got: ", outputs.size()));
140   }
141   if (classes) {
142     *classes = outputs[0];
143     TF_RETURN_IF_ERROR(BatchSizesMatch(input, *classes));
144   }
145   if (scores) {
146     *scores = outputs[classes ? 1 : 0];
147     TF_RETURN_IF_ERROR(BatchSizesMatch(input, *scores));
148   }
149   return Status::OK();
150 }
151 
GetRegressionSignature(const tensorflow::MetaGraphDef & meta_graph_def,RegressionSignature * signature)152 Status GetRegressionSignature(const tensorflow::MetaGraphDef& meta_graph_def,
153                               RegressionSignature* signature) {
154   Signatures signatures;
155   TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
156   if (!signatures.has_default_signature()) {
157     return errors::FailedPrecondition(
158         strings::StrCat("Expected a default signature in: ",
159                         DebugStringIfAvailable(signatures)));
160   }
161   if (!signatures.default_signature().has_regression_signature()) {
162     return errors::FailedPrecondition(strings::StrCat(
163         "Expected a regression signature in: ",
164         DebugStringIfAvailable(signatures.default_signature())));
165   }
166   *signature = signatures.default_signature().regression_signature();
167   return Status::OK();
168 }
169 
RunRegression(const RegressionSignature & signature,const Tensor & regression_input,Session * session,Tensor * regression_output)170 Status RunRegression(const RegressionSignature& signature,
171                      const Tensor& regression_input, Session* session,
172                      Tensor* regression_output) {
173   std::vector<string> output_tensor_names;
174   if (regression_output) {
175     output_tensor_names.push_back(signature.output().tensor_name());
176   }
177   // Run the graph with our inputs and outputs.
178   std::vector<Tensor> outputs;
179   const Status run_status =
180       session->Run({{signature.input().tensor_name(), regression_input}},
181                    output_tensor_names, {}, &outputs);
182   if (!run_status.ok()) {
183     return run_status;
184   }
185   // Ensure the regression score output is shaped how we expect.
186   // There should be one float Tensor of shape,
187   //   [batch_size, num_recommendations].
188   if (outputs.size() != output_tensor_names.size()) {
189     return errors::Internal(
190         strings::StrCat("Expected ", output_tensor_names.size(),
191                         " output tensor(s).  Got: ", outputs.size()));
192   }
193   if (regression_output) {
194     *regression_output = outputs[0];
195     TF_RETURN_IF_ERROR(BatchSizesMatch(regression_input, *regression_output));
196   }
197   return Status::OK();
198 }
199 
GetGenericSignature(const string & name,const tensorflow::MetaGraphDef & meta_graph_def,GenericSignature * signature)200 Status GetGenericSignature(const string& name,
201                            const tensorflow::MetaGraphDef& meta_graph_def,
202                            GenericSignature* signature) {
203   Signatures signatures;
204   TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
205   const auto& it = signatures.named_signatures().find(name);
206   if (it == signatures.named_signatures().end()) {
207     return errors::InvalidArgument(
208         strings::StrCat("Missing generic signature named \"", name, "\" in ",
209                         DebugStringIfAvailable(signatures)));
210   }
211   if (!it->second.has_generic_signature()) {
212     return errors::InvalidArgument(strings::StrCat(
213         "Expected a generic signature: ", DebugStringIfAvailable(it->second)));
214   }
215   *signature = it->second.generic_signature();
216   return Status::OK();
217 }
218 
GetDefaultSignature(const tensorflow::MetaGraphDef & meta_graph_def,Signature * default_signature)219 Status GetDefaultSignature(const tensorflow::MetaGraphDef& meta_graph_def,
220                            Signature* default_signature) {
221   Signatures signatures;
222   TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
223   *default_signature = signatures.default_signature();
224   return Status::OK();
225 }
226 
GetNamedSignature(const string & name,const tensorflow::MetaGraphDef & meta_graph_def,Signature * signature)227 Status GetNamedSignature(const string& name,
228                          const tensorflow::MetaGraphDef& meta_graph_def,
229                          Signature* signature) {
230   Signatures signatures;
231   TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
232   const auto& it = signatures.named_signatures().find(name);
233   if (it == signatures.named_signatures().end()) {
234     return errors::NotFound(
235         strings::StrCat("Missing signature named \"", name,
236                         "\" in: ", DebugStringIfAvailable(signatures)));
237   }
238   *signature = it->second;
239   return Status::OK();
240 }
241 
BindGenericInputs(const GenericSignature & signature,const std::vector<std::pair<string,Tensor>> & inputs,std::vector<std::pair<string,Tensor>> * bound_inputs)242 Status BindGenericInputs(const GenericSignature& signature,
243                          const std::vector<std::pair<string, Tensor>>& inputs,
244                          std::vector<std::pair<string, Tensor>>* bound_inputs) {
245   const protobuf::Map<string, serving::TensorBinding>& bindings =
246       signature.map();
247 
248   for (const auto& entry : inputs) {
249     const auto mapped = bindings.find(entry.first);
250     if (mapped == bindings.end()) {
251       return errors::NotFound(
252           strings::StrCat("Could not find generic binding for: ", entry.first));
253     }
254     bound_inputs->push_back({mapped->second.tensor_name(), entry.second});
255   }
256   return Status::OK();
257 }
258 
BindGenericNames(const GenericSignature & signature,const std::vector<string> & input_names,std::vector<string> * bound_names)259 Status BindGenericNames(const GenericSignature& signature,
260                         const std::vector<string>& input_names,
261                         std::vector<string>* bound_names) {
262   const protobuf::Map<string, serving::TensorBinding>& bindings =
263       signature.map();
264 
265   for (const string& entry : input_names) {
266     const auto mapped = bindings.find(entry);
267     if (mapped == bindings.end()) {
268       return errors::NotFound(
269           strings::StrCat("Could not find generic binding for: ", entry));
270     }
271     bound_names->push_back(mapped->second.tensor_name());
272   }
273   return Status::OK();
274 }
275 
276 }  // namespace serving
277 }  // namespace tensorflow
278