1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/contrib/session_bundle/signature.h"
17
18 #include <string>
19 #include <utility>
20 #include <vector>
21
22 #include "google/protobuf/any.pb.h"
23 #include "tensorflow/contrib/session_bundle/manifest.pb.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/lib/strings/strcat.h"
28 #include "tensorflow/core/platform/protobuf_internal.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/protobuf/meta_graph.pb.h"
31 #include "tensorflow/core/public/session.h"
32
33 namespace tensorflow {
34 namespace serving {
35 namespace {
36
37 // Returns OK if the input and output batch sizes match.
BatchSizesMatch(const Tensor & input,const Tensor & output)38 Status BatchSizesMatch(const Tensor& input, const Tensor& output) {
39 // Ensure the number of outputs match the number of inputs.
40 if (input.dim_size(0) != output.dim_size(0)) {
41 return errors::Internal(strings::StrCat(
42 "Input batch size did not match output batch size: ", input.dim_size(0),
43 " vs. ", output.dim_size(0)));
44 }
45 return Status::OK();
46 }
47 } // namespace
48
GetSignatures(const tensorflow::MetaGraphDef & meta_graph_def,Signatures * signatures)49 Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def,
50 Signatures* signatures) {
51 const auto& collection_def = meta_graph_def.collection_def();
52 const auto it = collection_def.find(kSignaturesKey);
53 if (it == collection_def.end() || it->second.any_list().value_size() != 1) {
54 return errors::FailedPrecondition(
55 strings::StrCat("Expected exactly one signatures proto in : ",
56 DebugStringIfAvailable(meta_graph_def)));
57 }
58 const auto& any = it->second.any_list().value(0);
59 return ParseAny(any, signatures, "tensorflow.serving.Signatures");
60 }
61
SetSignatures(const Signatures & signatures,tensorflow::MetaGraphDef * meta_graph_def)62 Status SetSignatures(const Signatures& signatures,
63 tensorflow::MetaGraphDef* meta_graph_def) {
64 auto& collection_def = *(meta_graph_def->mutable_collection_def());
65 auto* any_list = collection_def[kSignaturesKey].mutable_any_list();
66 any_list->mutable_value()->Clear();
67 #ifdef TENSORFLOW_LITE_PROTOS
68 signatures.SerializeToString(
69 any_list->mutable_value()->Add()->mutable_value());
70 #else
71 any_list->mutable_value()->Add()->PackFrom(signatures);
72 #endif
73 return Status::OK();
74 }
75
GetClassificationSignature(const tensorflow::MetaGraphDef & meta_graph_def,ClassificationSignature * signature)76 Status GetClassificationSignature(
77 const tensorflow::MetaGraphDef& meta_graph_def,
78 ClassificationSignature* signature) {
79 Signatures signatures;
80 TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
81 if (!signatures.has_default_signature()) {
82 return errors::FailedPrecondition(
83 strings::StrCat("Expected a default signature in: ",
84 DebugStringIfAvailable(signatures)));
85 }
86 if (!signatures.default_signature().has_classification_signature()) {
87 return errors::FailedPrecondition(strings::StrCat(
88 "Expected a classification signature in: ",
89 DebugStringIfAvailable(signatures.default_signature())));
90 }
91 *signature = signatures.default_signature().classification_signature();
92 return Status::OK();
93 }
94
GetNamedClassificationSignature(const string & name,const tensorflow::MetaGraphDef & meta_graph_def,ClassificationSignature * signature)95 Status GetNamedClassificationSignature(
96 const string& name, const tensorflow::MetaGraphDef& meta_graph_def,
97 ClassificationSignature* signature) {
98 Signatures signatures;
99 TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
100 const auto& it = signatures.named_signatures().find(name);
101 if (it == signatures.named_signatures().end()) {
102 return errors::NotFound(
103 strings::StrCat("Missing signature named \"", name,
104 "\" in: ", DebugStringIfAvailable(signatures)));
105 }
106 if (!it->second.has_classification_signature()) {
107 return errors::FailedPrecondition(
108 strings::StrCat("Expected a classification signature for name \"", name,
109 "\" in: ", DebugStringIfAvailable(it->second)));
110 }
111 *signature = it->second.classification_signature();
112 return Status::OK();
113 }
114
RunClassification(const ClassificationSignature & signature,const Tensor & input,Session * session,Tensor * classes,Tensor * scores)115 Status RunClassification(const ClassificationSignature& signature,
116 const Tensor& input, Session* session, Tensor* classes,
117 Tensor* scores) {
118 std::vector<string> output_tensor_names;
119 if (classes) {
120 output_tensor_names.push_back(signature.classes().tensor_name());
121 }
122 if (scores) {
123 output_tensor_names.push_back(signature.scores().tensor_name());
124 }
125 // Run the graph with our inputs and outputs.
126 std::vector<Tensor> outputs;
127 const Status run_status =
128 session->Run({{signature.input().tensor_name(), input}},
129 output_tensor_names, {}, &outputs);
130 if (!run_status.ok()) {
131 return run_status;
132 }
133 // Ensure the output is shaped how we expect.
134 // There should be one string Tensor of shape,
135 // [batch_size, num_recommendations].
136 if (outputs.size() != output_tensor_names.size()) {
137 return errors::Internal(
138 strings::StrCat("Expected ", output_tensor_names.size(),
139 " output tensor(s). Got: ", outputs.size()));
140 }
141 if (classes) {
142 *classes = outputs[0];
143 TF_RETURN_IF_ERROR(BatchSizesMatch(input, *classes));
144 }
145 if (scores) {
146 *scores = outputs[classes ? 1 : 0];
147 TF_RETURN_IF_ERROR(BatchSizesMatch(input, *scores));
148 }
149 return Status::OK();
150 }
151
GetRegressionSignature(const tensorflow::MetaGraphDef & meta_graph_def,RegressionSignature * signature)152 Status GetRegressionSignature(const tensorflow::MetaGraphDef& meta_graph_def,
153 RegressionSignature* signature) {
154 Signatures signatures;
155 TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
156 if (!signatures.has_default_signature()) {
157 return errors::FailedPrecondition(
158 strings::StrCat("Expected a default signature in: ",
159 DebugStringIfAvailable(signatures)));
160 }
161 if (!signatures.default_signature().has_regression_signature()) {
162 return errors::FailedPrecondition(strings::StrCat(
163 "Expected a regression signature in: ",
164 DebugStringIfAvailable(signatures.default_signature())));
165 }
166 *signature = signatures.default_signature().regression_signature();
167 return Status::OK();
168 }
169
RunRegression(const RegressionSignature & signature,const Tensor & regression_input,Session * session,Tensor * regression_output)170 Status RunRegression(const RegressionSignature& signature,
171 const Tensor& regression_input, Session* session,
172 Tensor* regression_output) {
173 std::vector<string> output_tensor_names;
174 if (regression_output) {
175 output_tensor_names.push_back(signature.output().tensor_name());
176 }
177 // Run the graph with our inputs and outputs.
178 std::vector<Tensor> outputs;
179 const Status run_status =
180 session->Run({{signature.input().tensor_name(), regression_input}},
181 output_tensor_names, {}, &outputs);
182 if (!run_status.ok()) {
183 return run_status;
184 }
185 // Ensure the regression score output is shaped how we expect.
186 // There should be one float Tensor of shape,
187 // [batch_size, num_recommendations].
188 if (outputs.size() != output_tensor_names.size()) {
189 return errors::Internal(
190 strings::StrCat("Expected ", output_tensor_names.size(),
191 " output tensor(s). Got: ", outputs.size()));
192 }
193 if (regression_output) {
194 *regression_output = outputs[0];
195 TF_RETURN_IF_ERROR(BatchSizesMatch(regression_input, *regression_output));
196 }
197 return Status::OK();
198 }
199
GetGenericSignature(const string & name,const tensorflow::MetaGraphDef & meta_graph_def,GenericSignature * signature)200 Status GetGenericSignature(const string& name,
201 const tensorflow::MetaGraphDef& meta_graph_def,
202 GenericSignature* signature) {
203 Signatures signatures;
204 TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
205 const auto& it = signatures.named_signatures().find(name);
206 if (it == signatures.named_signatures().end()) {
207 return errors::InvalidArgument(
208 strings::StrCat("Missing generic signature named \"", name, "\" in ",
209 DebugStringIfAvailable(signatures)));
210 }
211 if (!it->second.has_generic_signature()) {
212 return errors::InvalidArgument(strings::StrCat(
213 "Expected a generic signature: ", DebugStringIfAvailable(it->second)));
214 }
215 *signature = it->second.generic_signature();
216 return Status::OK();
217 }
218
GetDefaultSignature(const tensorflow::MetaGraphDef & meta_graph_def,Signature * default_signature)219 Status GetDefaultSignature(const tensorflow::MetaGraphDef& meta_graph_def,
220 Signature* default_signature) {
221 Signatures signatures;
222 TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
223 *default_signature = signatures.default_signature();
224 return Status::OK();
225 }
226
GetNamedSignature(const string & name,const tensorflow::MetaGraphDef & meta_graph_def,Signature * signature)227 Status GetNamedSignature(const string& name,
228 const tensorflow::MetaGraphDef& meta_graph_def,
229 Signature* signature) {
230 Signatures signatures;
231 TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
232 const auto& it = signatures.named_signatures().find(name);
233 if (it == signatures.named_signatures().end()) {
234 return errors::NotFound(
235 strings::StrCat("Missing signature named \"", name,
236 "\" in: ", DebugStringIfAvailable(signatures)));
237 }
238 *signature = it->second;
239 return Status::OK();
240 }
241
BindGenericInputs(const GenericSignature & signature,const std::vector<std::pair<string,Tensor>> & inputs,std::vector<std::pair<string,Tensor>> * bound_inputs)242 Status BindGenericInputs(const GenericSignature& signature,
243 const std::vector<std::pair<string, Tensor>>& inputs,
244 std::vector<std::pair<string, Tensor>>* bound_inputs) {
245 const protobuf::Map<string, serving::TensorBinding>& bindings =
246 signature.map();
247
248 for (const auto& entry : inputs) {
249 const auto mapped = bindings.find(entry.first);
250 if (mapped == bindings.end()) {
251 return errors::NotFound(
252 strings::StrCat("Could not find generic binding for: ", entry.first));
253 }
254 bound_inputs->push_back({mapped->second.tensor_name(), entry.second});
255 }
256 return Status::OK();
257 }
258
BindGenericNames(const GenericSignature & signature,const std::vector<string> & input_names,std::vector<string> * bound_names)259 Status BindGenericNames(const GenericSignature& signature,
260 const std::vector<string>& input_names,
261 std::vector<string>* bound_names) {
262 const protobuf::Map<string, serving::TensorBinding>& bindings =
263 signature.map();
264
265 for (const string& entry : input_names) {
266 const auto mapped = bindings.find(entry);
267 if (mapped == bindings.end()) {
268 return errors::NotFound(
269 strings::StrCat("Could not find generic binding for: ", entry));
270 }
271 bound_names->push_back(mapped->second.tensor_name());
272 }
273 return Status::OK();
274 }
275
276 } // namespace serving
277 } // namespace tensorflow
278