1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
17 
18 #include "tensorflow/core/framework/cost_graph.pb.h"
19 #include "tensorflow/core/framework/step_stats.pb.h"
20 #include "tensorflow/core/framework/tensor.pb.h"
21 #include "tensorflow/core/protobuf/config.pb.h"
22 #include "tensorflow/core/protobuf/named_tensor.pb.h"
23 
24 namespace tensorflow {
25 
ParseTensorProtoToTensor(const TensorProto & tensor_proto,Tensor * out_tensor)26 bool ParseTensorProtoToTensor(const TensorProto& tensor_proto,
27                               Tensor* out_tensor) {
28   if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
29     Tensor parsed(tensor_proto.dtype());
30     if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
31       *out_tensor = parsed;
32       return true;
33     }
34   }
35   return false;
36 }
37 
session_handle() const38 const string& InMemoryRunStepRequest::session_handle() const {
39   return session_handle_;
40 }
41 
set_session_handle(const string & handle)42 void InMemoryRunStepRequest::set_session_handle(const string& handle) {
43   session_handle_ = handle;
44 }
45 
partial_run_handle() const46 const string& InMemoryRunStepRequest::partial_run_handle() const {
47   return partial_run_handle_;
48 }
49 
set_partial_run_handle(const string & handle)50 void InMemoryRunStepRequest::set_partial_run_handle(const string& handle) {
51   partial_run_handle_ = handle;
52 }
53 
num_feeds() const54 size_t InMemoryRunStepRequest::num_feeds() const { return feeds_.size(); }
feed_name(size_t i) const55 const string& InMemoryRunStepRequest::feed_name(size_t i) const {
56   return feeds_[i].first;
57 }
58 
FeedValue(size_t i,Tensor * out_tensor) const59 Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
60   *out_tensor = feeds_[i].second;
61   return Status::OK();
62 }
63 
FeedValue(size_t i,TensorProto * out_tensor) const64 Status InMemoryRunStepRequest::FeedValue(size_t i,
65                                          TensorProto* out_tensor) const {
66   feeds_[i].second.AsProtoTensorContent(out_tensor);
67   return Status::OK();
68 }
69 
add_feed(const string & name,const Tensor & value)70 void InMemoryRunStepRequest::add_feed(const string& name, const Tensor& value) {
71   feeds_.emplace_back(name, value);
72 }
73 
num_fetches() const74 size_t InMemoryRunStepRequest::num_fetches() const { return fetches_.size(); }
fetch_name(size_t i) const75 const string& InMemoryRunStepRequest::fetch_name(size_t i) const {
76   return fetches_[i];
77 }
add_fetch(const string & name)78 void InMemoryRunStepRequest::add_fetch(const string& name) {
79   fetches_.push_back(name);
80 }
81 
num_targets() const82 size_t InMemoryRunStepRequest::num_targets() const { return targets_.size(); }
target_name(size_t i) const83 const string& InMemoryRunStepRequest::target_name(size_t i) const {
84   return targets_[i];
85 }
add_target(const string & name)86 void InMemoryRunStepRequest::add_target(const string& name) {
87   targets_.push_back(name);
88 }
89 
options() const90 const RunOptions& InMemoryRunStepRequest::options() const { return options_; }
91 
mutable_options()92 RunOptions* InMemoryRunStepRequest::mutable_options() { return &options_; }
93 
store_errors_in_response_body() const94 bool InMemoryRunStepRequest::store_errors_in_response_body() const {
95   return store_errors_in_response_body_;
96 }
97 
request_id() const98 int64 InMemoryRunStepRequest::request_id() const {
99   return 0;  // no need to track request id for local version.
100 }
101 
set_store_errors_in_response_body(bool store_errors)102 void InMemoryRunStepRequest::set_store_errors_in_response_body(
103     bool store_errors) {
104   store_errors_in_response_body_ = store_errors;
105 }
106 
DebugString() const107 string InMemoryRunStepRequest::DebugString() const {
108   return ToProto().DebugString();
109 }
110 
ToProto() const111 const RunStepRequest& InMemoryRunStepRequest::ToProto() const {
112   if (!proto_version_) {
113     proto_version_.reset(new RunStepRequest);
114     proto_version_->set_session_handle(session_handle());
115     proto_version_->set_partial_run_handle(partial_run_handle());
116     for (size_t i = 0; i < num_feeds(); ++i) {
117       auto feed = proto_version_->add_feed();
118       feed->set_name(feed_name(i));
119       feeds_[i].second.AsProtoTensorContent(feed->mutable_tensor());
120     }
121     for (size_t i = 0; i < num_fetches(); ++i) {
122       proto_version_->add_fetch(fetch_name(i));
123     }
124     for (size_t i = 0; i < num_targets(); ++i) {
125       proto_version_->add_target(target_name(i));
126     }
127     *proto_version_->mutable_options() = options();
128   }
129   return *proto_version_;
130 }
131 
session_handle() const132 const string& MutableProtoRunStepRequest::session_handle() const {
133   return request_.session_handle();
134 }
set_session_handle(const string & handle)135 void MutableProtoRunStepRequest::set_session_handle(const string& handle) {
136   request_.set_session_handle(handle);
137 }
138 
partial_run_handle() const139 const string& MutableProtoRunStepRequest::partial_run_handle() const {
140   return request_.partial_run_handle();
141 }
set_partial_run_handle(const string & handle)142 void MutableProtoRunStepRequest::set_partial_run_handle(const string& handle) {
143   request_.set_partial_run_handle(handle);
144 }
145 
num_feeds() const146 size_t MutableProtoRunStepRequest::num_feeds() const {
147   return request_.feed_size();
148 }
feed_name(size_t i) const149 const string& MutableProtoRunStepRequest::feed_name(size_t i) const {
150   return request_.feed(i).name();
151 }
FeedValue(size_t i,Tensor * out_tensor) const152 Status MutableProtoRunStepRequest::FeedValue(size_t i,
153                                              Tensor* out_tensor) const {
154   if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) {
155     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
156   } else {
157     return Status::OK();
158   }
159 }
160 
FeedValue(size_t i,TensorProto * out_tensor) const161 Status MutableProtoRunStepRequest::FeedValue(size_t i,
162                                              TensorProto* out_tensor) const {
163   *out_tensor = request_.feed(i).tensor();
164   return Status::OK();
165 }
166 
add_feed(const string & name,const Tensor & value)167 void MutableProtoRunStepRequest::add_feed(const string& name,
168                                           const Tensor& value) {
169   NamedTensorProto* feed = request_.add_feed();
170   feed->set_name(name);
171   TensorProto* value_proto = feed->mutable_tensor();
172   value.AsProtoTensorContent(value_proto);
173 }
174 
num_fetches() const175 size_t MutableProtoRunStepRequest::num_fetches() const {
176   return request_.fetch_size();
177 }
178 
fetch_name(size_t i) const179 const string& MutableProtoRunStepRequest::fetch_name(size_t i) const {
180   return request_.fetch(i);
181 }
add_fetch(const string & name)182 void MutableProtoRunStepRequest::add_fetch(const string& name) {
183   request_.add_fetch(name);
184 }
185 
num_targets() const186 size_t MutableProtoRunStepRequest::num_targets() const {
187   return request_.target_size();
188 }
189 
target_name(size_t i) const190 const string& MutableProtoRunStepRequest::target_name(size_t i) const {
191   return request_.target(i);
192 }
193 
add_target(const string & name)194 void MutableProtoRunStepRequest::add_target(const string& name) {
195   request_.add_target(name);
196 }
197 
options() const198 const RunOptions& MutableProtoRunStepRequest::options() const {
199   return request_.options();
200 }
201 
mutable_options()202 RunOptions* MutableProtoRunStepRequest::mutable_options() {
203   return request_.mutable_options();
204 }
205 
store_errors_in_response_body() const206 bool MutableProtoRunStepRequest::store_errors_in_response_body() const {
207   return request_.store_errors_in_response_body();
208 }
209 
set_store_errors_in_response_body(bool store_errors)210 void MutableProtoRunStepRequest::set_store_errors_in_response_body(
211     bool store_errors) {
212   request_.set_store_errors_in_response_body(store_errors);
213 }
214 
request_id() const215 int64 MutableProtoRunStepRequest::request_id() const {
216   return request_.request_id();
217 }
218 
DebugString() const219 string MutableProtoRunStepRequest::DebugString() const {
220   return request_.DebugString();
221 }
222 
ToProto() const223 const RunStepRequest& MutableProtoRunStepRequest::ToProto() const {
224   return request_;
225 }
226 
ProtoRunStepRequest(const RunStepRequest * request)227 ProtoRunStepRequest::ProtoRunStepRequest(const RunStepRequest* request)
228     : request_(request) {}
229 
session_handle() const230 const string& ProtoRunStepRequest::session_handle() const {
231   return request_->session_handle();
232 }
233 
partial_run_handle() const234 const string& ProtoRunStepRequest::partial_run_handle() const {
235   return request_->partial_run_handle();
236 }
237 
num_feeds() const238 size_t ProtoRunStepRequest::num_feeds() const { return request_->feed_size(); }
239 
feed_name(size_t i) const240 const string& ProtoRunStepRequest::feed_name(size_t i) const {
241   return request_->feed(i).name();
242 }
243 
FeedValue(size_t i,Tensor * out_tensor) const244 Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const {
245   if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) {
246     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
247   } else {
248     return Status::OK();
249   }
250 }
251 
FeedValue(size_t i,TensorProto * out_tensor) const252 Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const {
253   *out_tensor = request_->feed(i).tensor();
254   return Status::OK();
255 }
256 
num_fetches() const257 size_t ProtoRunStepRequest::num_fetches() const {
258   return request_->fetch_size();
259 }
260 
fetch_name(size_t i) const261 const string& ProtoRunStepRequest::fetch_name(size_t i) const {
262   return request_->fetch(i);
263 }
264 
num_targets() const265 size_t ProtoRunStepRequest::num_targets() const {
266   return request_->target_size();
267 }
268 
target_name(size_t i) const269 const string& ProtoRunStepRequest::target_name(size_t i) const {
270   return request_->target(i);
271 }
272 
options() const273 const RunOptions& ProtoRunStepRequest::options() const {
274   return request_->options();
275 }
276 
store_errors_in_response_body() const277 bool ProtoRunStepRequest::store_errors_in_response_body() const {
278   return request_->store_errors_in_response_body();
279 }
280 
request_id() const281 int64 ProtoRunStepRequest::request_id() const { return request_->request_id(); }
282 
DebugString() const283 string ProtoRunStepRequest::DebugString() const {
284   return request_->DebugString();
285 }
286 
ToProto() const287 const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
288 
session_handle() const289 const string& InMemoryRunGraphRequest::session_handle() const {
290   return session_handle_;
291 }
292 
create_worker_session_called() const293 bool InMemoryRunGraphRequest::create_worker_session_called() const {
294   return create_worker_session_called_;
295 }
296 
set_session_handle(const string & handle)297 void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
298   session_handle_ = handle;
299 }
300 
set_create_worker_session_called(bool called)301 void InMemoryRunGraphRequest::set_create_worker_session_called(bool called) {
302   create_worker_session_called_ = called;
303 }
304 
graph_handle() const305 const string& InMemoryRunGraphRequest::graph_handle() const {
306   return graph_handle_;
307 }
308 
set_graph_handle(const string & handle)309 void InMemoryRunGraphRequest::set_graph_handle(const string& handle) {
310   graph_handle_ = handle;
311 }
312 
step_id() const313 int64 InMemoryRunGraphRequest::step_id() const { return step_id_; }
314 
set_step_id(int64 step_id)315 void InMemoryRunGraphRequest::set_step_id(int64 step_id) { step_id_ = step_id; }
316 
exec_opts() const317 const ExecutorOpts& InMemoryRunGraphRequest::exec_opts() const {
318   return exec_opts_;
319 }
320 
mutable_exec_opts()321 ExecutorOpts* InMemoryRunGraphRequest::mutable_exec_opts() {
322   return &exec_opts_;
323 }
324 
num_sends() const325 size_t InMemoryRunGraphRequest::num_sends() const { return sends_.size(); }
326 
send_key(size_t i) const327 const string& InMemoryRunGraphRequest::send_key(size_t i) const {
328   return sends_[i].first;
329 }
330 
SendValue(size_t i,Tensor * out_tensor) const331 Status InMemoryRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
332   *out_tensor = sends_[i].second;
333   return Status::OK();
334 }
335 
AddSendFromRunStepRequest(const RunStepRequestWrapper & run_step_request,size_t i,const string & send_key)336 Status InMemoryRunGraphRequest::AddSendFromRunStepRequest(
337     const RunStepRequestWrapper& run_step_request, size_t i,
338     const string& send_key) {
339   Tensor tensor;
340   TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, &tensor));
341   sends_.emplace_back(send_key, std::move(tensor));
342   return Status::OK();
343 }
344 
345 // TODO(b/74355905): Add a specialized implementation that avoids
346 // copying the tensor when at least two of the {client, master,
347 // worker} are in the same process.
AddSendFromRunCallableRequest(const RunCallableRequest & run_callable_request,size_t i,const string & send_key)348 Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest(
349     const RunCallableRequest& run_callable_request, size_t i,
350     const string& send_key) {
351   Tensor tensor;
352   if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) {
353     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
354   }
355   sends_.emplace_back(send_key, std::move(tensor));
356   return Status::OK();
357 }
358 
num_recvs() const359 size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); }
360 
recv_key(size_t i) const361 const string& InMemoryRunGraphRequest::recv_key(size_t i) const {
362   return recvs_[i];
363 }
364 
add_recv_key(const string & recv_key)365 void InMemoryRunGraphRequest::add_recv_key(const string& recv_key) {
366   recvs_.push_back(recv_key);
367 }
368 
is_partial() const369 bool InMemoryRunGraphRequest::is_partial() const { return is_partial_; }
370 
set_is_partial(bool is_partial)371 void InMemoryRunGraphRequest::set_is_partial(bool is_partial) {
372   is_partial_ = is_partial;
373 }
374 
is_last_partial_run() const375 bool InMemoryRunGraphRequest::is_last_partial_run() const {
376   return is_last_partial_run_;
377 }
378 
set_is_last_partial_run(bool is_last_partial_run)379 void InMemoryRunGraphRequest::set_is_last_partial_run(
380     bool is_last_partial_run) {
381   is_last_partial_run_ = is_last_partial_run;
382 }
383 
store_errors_in_response_body() const384 bool InMemoryRunGraphRequest::store_errors_in_response_body() const {
385   return store_errors_in_response_body_;
386 }
387 
set_store_errors_in_response_body(bool store_errors)388 void InMemoryRunGraphRequest::set_store_errors_in_response_body(
389     bool store_errors) {
390   store_errors_in_response_body_ = store_errors;
391 }
392 
request_id() const393 int64 InMemoryRunGraphRequest::request_id() const { return request_id_; }
394 
set_request_id(int64 request_id)395 void InMemoryRunGraphRequest::set_request_id(int64 request_id) {
396   request_id_ = request_id;
397 }
398 
ToProto() const399 const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
400   if (!proto_version_) {
401     proto_version_.reset(new RunGraphRequest);
402     proto_version_->set_session_handle(session_handle());
403     proto_version_->set_create_worker_session_called(
404         create_worker_session_called());
405     proto_version_->set_graph_handle(graph_handle());
406     proto_version_->set_step_id(step_id());
407     *proto_version_->mutable_exec_opts() = exec_opts();
408     for (size_t i = 0; i < num_sends(); ++i) {
409       auto send = proto_version_->add_send();
410       send->set_name(send_key(i));
411       sends_[i].second.AsProtoTensorContent(send->mutable_tensor());
412     }
413     for (size_t i = 0; i < num_recvs(); ++i) {
414       proto_version_->add_recv_key(recv_key(i));
415     }
416     proto_version_->set_is_partial(is_partial());
417     proto_version_->set_is_last_partial_run(is_last_partial_run());
418   }
419   proto_version_->set_store_errors_in_response_body(
420       store_errors_in_response_body_);
421   proto_version_->set_request_id(request_id_);
422   return *proto_version_;
423 }
424 
session_handle() const425 const string& MutableProtoRunGraphRequest::session_handle() const {
426   return request_.session_handle();
427 }
428 
set_session_handle(const string & handle)429 void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
430   request_.set_session_handle(handle);
431 }
432 
create_worker_session_called() const433 bool MutableProtoRunGraphRequest::create_worker_session_called() const {
434   return request_.create_worker_session_called();
435 }
436 
set_create_worker_session_called(bool called)437 void MutableProtoRunGraphRequest::set_create_worker_session_called(
438     bool called) {
439   request_.set_create_worker_session_called(called);
440 }
441 
graph_handle() const442 const string& MutableProtoRunGraphRequest::graph_handle() const {
443   return request_.graph_handle();
444 }
445 
set_graph_handle(const string & handle)446 void MutableProtoRunGraphRequest::set_graph_handle(const string& handle) {
447   request_.set_graph_handle(handle);
448 }
449 
step_id() const450 int64 MutableProtoRunGraphRequest::step_id() const {
451   return request_.step_id();
452 }
453 
set_step_id(int64 step_id)454 void MutableProtoRunGraphRequest::set_step_id(int64 step_id) {
455   request_.set_step_id(step_id);
456 }
457 
exec_opts() const458 const ExecutorOpts& MutableProtoRunGraphRequest::exec_opts() const {
459   return request_.exec_opts();
460 }
461 
mutable_exec_opts()462 ExecutorOpts* MutableProtoRunGraphRequest::mutable_exec_opts() {
463   return request_.mutable_exec_opts();
464 }
465 
num_sends() const466 size_t MutableProtoRunGraphRequest::num_sends() const {
467   return request_.send_size();
468 }
469 
send_key(size_t i) const470 const string& MutableProtoRunGraphRequest::send_key(size_t i) const {
471   return request_.send(i).name();
472 }
473 
SendValue(size_t i,Tensor * out_tensor) const474 Status MutableProtoRunGraphRequest::SendValue(size_t i,
475                                               Tensor* out_tensor) const {
476   if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) {
477     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
478   } else {
479     return Status::OK();
480   }
481 }
482 
AddSendFromRunStepRequest(const RunStepRequestWrapper & run_step_request,size_t i,const string & send_key)483 Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest(
484     const RunStepRequestWrapper& run_step_request, size_t i,
485     const string& send_key) {
486   NamedTensorProto* send = request_.add_send();
487   send->set_name(send_key);
488   TF_RETURN_IF_ERROR(run_step_request.FeedValue(i, send->mutable_tensor()));
489   return Status::OK();
490 }
491 
492 // TODO(b/74355905): Add a specialized implementation that avoids
493 // copying the tensor when at least two of the {client, master,
494 // worker} are in the same process.
AddSendFromRunCallableRequest(const RunCallableRequest & run_callable_request,size_t i,const string & send_key)495 Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest(
496     const RunCallableRequest& run_callable_request, size_t i,
497     const string& send_key) {
498   NamedTensorProto* send = request_.add_send();
499   send->set_name(send_key);
500   *send->mutable_tensor() = run_callable_request.feed(i);
501   return Status::OK();
502 }
503 
num_recvs() const504 size_t MutableProtoRunGraphRequest::num_recvs() const {
505   return request_.recv_key_size();
506 }
507 
recv_key(size_t i) const508 const string& MutableProtoRunGraphRequest::recv_key(size_t i) const {
509   return request_.recv_key(i);
510 }
511 
add_recv_key(const string & recv_key)512 void MutableProtoRunGraphRequest::add_recv_key(const string& recv_key) {
513   request_.add_recv_key(recv_key);
514 }
515 
is_partial() const516 bool MutableProtoRunGraphRequest::is_partial() const {
517   return request_.is_partial();
518 }
519 
set_is_partial(bool is_partial)520 void MutableProtoRunGraphRequest::set_is_partial(bool is_partial) {
521   request_.set_is_partial(is_partial);
522 }
523 
is_last_partial_run() const524 bool MutableProtoRunGraphRequest::is_last_partial_run() const {
525   return request_.is_last_partial_run();
526 }
527 
set_is_last_partial_run(bool is_last_partial_run)528 void MutableProtoRunGraphRequest::set_is_last_partial_run(
529     bool is_last_partial_run) {
530   request_.set_is_last_partial_run(is_last_partial_run);
531 }
532 
store_errors_in_response_body() const533 bool MutableProtoRunGraphRequest::store_errors_in_response_body() const {
534   return request_.store_errors_in_response_body();
535 }
536 
set_store_errors_in_response_body(bool store_errors)537 void MutableProtoRunGraphRequest::set_store_errors_in_response_body(
538     bool store_errors) {
539   request_.set_store_errors_in_response_body(store_errors);
540 }
541 
request_id() const542 int64 MutableProtoRunGraphRequest::request_id() const {
543   return request_.request_id();
544 }
545 
set_request_id(int64 request_id)546 void MutableProtoRunGraphRequest::set_request_id(int64 request_id) {
547   request_.set_request_id(request_id);
548 }
549 
ToProto() const550 const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
551   return request_;
552 }
553 
ProtoRunGraphRequest(const RunGraphRequest * request)554 ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
555     : request_(request) {}
556 
session_handle() const557 const string& ProtoRunGraphRequest::session_handle() const {
558   return request_->session_handle();
559 }
560 
create_worker_session_called() const561 bool ProtoRunGraphRequest::create_worker_session_called() const {
562   return request_->create_worker_session_called();
563 }
564 
graph_handle() const565 const string& ProtoRunGraphRequest::graph_handle() const {
566   return request_->graph_handle();
567 }
568 
step_id() const569 int64 ProtoRunGraphRequest::step_id() const { return request_->step_id(); }
570 
exec_opts() const571 const ExecutorOpts& ProtoRunGraphRequest::exec_opts() const {
572   return request_->exec_opts();
573 }
574 
num_sends() const575 size_t ProtoRunGraphRequest::num_sends() const { return request_->send_size(); }
576 
send_key(size_t i) const577 const string& ProtoRunGraphRequest::send_key(size_t i) const {
578   return request_->send(i).name();
579 }
580 
SendValue(size_t i,Tensor * out_tensor) const581 Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const {
582   if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) {
583     return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
584   } else {
585     return Status::OK();
586   }
587 }
588 
num_recvs() const589 size_t ProtoRunGraphRequest::num_recvs() const {
590   return request_->recv_key_size();
591 }
592 
recv_key(size_t i) const593 const string& ProtoRunGraphRequest::recv_key(size_t i) const {
594   return request_->recv_key(i);
595 }
596 
is_partial() const597 bool ProtoRunGraphRequest::is_partial() const { return request_->is_partial(); }
598 
is_last_partial_run() const599 bool ProtoRunGraphRequest::is_last_partial_run() const {
600   return request_->is_last_partial_run();
601 }
602 
store_errors_in_response_body() const603 bool ProtoRunGraphRequest::store_errors_in_response_body() const {
604   return request_->store_errors_in_response_body();
605 }
606 
request_id() const607 int64 ProtoRunGraphRequest::request_id() const {
608   return request_->request_id();
609 }
610 
ToProto() const611 const RunGraphRequest& ProtoRunGraphRequest::ToProto() const {
612   return *request_;
613 }
614 
num_recvs() const615 size_t InMemoryRunGraphResponse::num_recvs() const { return recvs_.size(); }
616 
recv_key(size_t i) const617 const string& InMemoryRunGraphResponse::recv_key(size_t i) const {
618   return recvs_[i].first;
619 }
620 
RecvValue(size_t i,TensorProto * out_tensor)621 Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) {
622   recvs_[i].second.AsProtoTensorContent(out_tensor);
623   return Status::OK();
624 }
625 
RecvValue(size_t i,Tensor * out_tensor)626 Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
627   *out_tensor = recvs_[i].second;
628   return Status::OK();
629 }
630 
AddRecv(const string & key,const Tensor & value)631 void InMemoryRunGraphResponse::AddRecv(const string& key, const Tensor& value) {
632   recvs_.emplace_back(key, value);
633 }
634 
mutable_step_stats()635 StepStats* InMemoryRunGraphResponse::mutable_step_stats() {
636   return &step_stats_;
637 }
638 
mutable_cost_graph()639 CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() {
640   return &cost_graph_;
641 }
642 
status_code() const643 errors::Code InMemoryRunGraphResponse::status_code() const {
644   return status_.code();
645 }
646 
status_error_message() const647 const string& InMemoryRunGraphResponse::status_error_message() const {
648   return status_.error_message();
649 }
650 
set_status(const Status & status)651 void InMemoryRunGraphResponse::set_status(const Status& status) {
652   status_ = status;
653 }
654 
get_proto()655 RunGraphResponse* InMemoryRunGraphResponse::get_proto() {
656   LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunGraphResponse";
657   return nullptr;
658 }
659 
num_partition_graphs() const660 size_t InMemoryRunGraphResponse::num_partition_graphs() const {
661   return partition_graphs_.size();
662 }
663 
mutable_partition_graph(size_t i)664 GraphDef* InMemoryRunGraphResponse::mutable_partition_graph(size_t i) {
665   return &partition_graphs_[i];
666 }
667 
AddPartitionGraph(const GraphDef & partition_graph)668 void InMemoryRunGraphResponse::AddPartitionGraph(
669     const GraphDef& partition_graph) {
670   partition_graphs_.push_back(partition_graph);
671 }
672 
num_recvs() const673 size_t OwnedProtoRunGraphResponse::num_recvs() const {
674   return response_.recv_size();
675 }
676 
recv_key(size_t i) const677 const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const {
678   return response_.recv(i).name();
679 }
680 
RecvValue(size_t i,TensorProto * out_tensor)681 Status OwnedProtoRunGraphResponse::RecvValue(size_t i,
682                                              TensorProto* out_tensor) {
683   out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor());
684   return Status::OK();
685 }
686 
RecvValue(size_t i,Tensor * out_tensor)687 Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
688   if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) {
689     return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
690   } else {
691     return Status::OK();
692   }
693 }
694 
AddRecv(const string & key,const Tensor & value)695 void OwnedProtoRunGraphResponse::AddRecv(const string& key,
696                                          const Tensor& value) {
697   NamedTensorProto* recv = response_.add_recv();
698   recv->set_name(key);
699   TensorProto* value_proto = recv->mutable_tensor();
700   value.AsProtoTensorContent(value_proto);
701 }
702 
mutable_step_stats()703 StepStats* OwnedProtoRunGraphResponse::mutable_step_stats() {
704   return response_.mutable_step_stats();
705 }
706 
mutable_cost_graph()707 CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() {
708   return response_.mutable_cost_graph();
709 }
710 
status_code() const711 errors::Code OwnedProtoRunGraphResponse::status_code() const {
712   return response_.status_code();
713 }
714 
status_error_message() const715 const string& OwnedProtoRunGraphResponse::status_error_message() const {
716   return response_.status_error_message();
717 }
718 
set_status(const Status & status)719 void OwnedProtoRunGraphResponse::set_status(const Status& status) {
720   response_.set_status_code(status.code());
721   response_.set_status_error_message(status.error_message());
722 }
723 
get_proto()724 RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; }
725 
num_partition_graphs() const726 size_t OwnedProtoRunGraphResponse::num_partition_graphs() const {
727   return response_.partition_graph_size();
728 }
729 
mutable_partition_graph(size_t i)730 GraphDef* OwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
731   return response_.mutable_partition_graph(i);
732 }
733 
AddPartitionGraph(const GraphDef & partition_graph)734 void OwnedProtoRunGraphResponse::AddPartitionGraph(
735     const GraphDef& partition_graph) {
736   GraphDef* graph_def = response_.mutable_partition_graph()->Add();
737   *graph_def = partition_graph;
738 }
739 
NonOwnedProtoRunGraphResponse(RunGraphResponse * response)740 NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse(
741     RunGraphResponse* response)
742     : response_(response) {}
743 
num_recvs() const744 size_t NonOwnedProtoRunGraphResponse::num_recvs() const {
745   return response_->recv_size();
746 }
747 
recv_key(size_t i) const748 const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const {
749   return response_->recv(i).name();
750 }
751 
RecvValue(size_t i,TensorProto * out_tensor)752 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i,
753                                                 TensorProto* out_tensor) {
754   out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor());
755   return Status::OK();
756 }
757 
RecvValue(size_t i,Tensor * out_tensor)758 Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) {
759   if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) {
760     return errors::InvalidArgument("Invalid TensorProto for recv value ", i);
761   } else {
762     return Status::OK();
763   }
764 }
765 
AddRecv(const string & key,const Tensor & value)766 void NonOwnedProtoRunGraphResponse::AddRecv(const string& key,
767                                             const Tensor& value) {
768   NamedTensorProto* recv = response_->add_recv();
769   recv->set_name(key);
770   TensorProto* value_proto = recv->mutable_tensor();
771   value.AsProtoTensorContent(value_proto);
772 }
773 
mutable_step_stats()774 StepStats* NonOwnedProtoRunGraphResponse::mutable_step_stats() {
775   return response_->mutable_step_stats();
776 }
777 
mutable_cost_graph()778 CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() {
779   return response_->mutable_cost_graph();
780 }
781 
status_code() const782 errors::Code NonOwnedProtoRunGraphResponse::status_code() const {
783   return response_->status_code();
784 }
785 
status_error_message() const786 const string& NonOwnedProtoRunGraphResponse::status_error_message() const {
787   return response_->status_error_message();
788 }
789 
set_status(const Status & status)790 void NonOwnedProtoRunGraphResponse::set_status(const Status& status) {
791   response_->set_status_code(status.code());
792   response_->set_status_error_message(status.error_message());
793 }
794 
get_proto()795 RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() {
796   return response_;
797 }
798 
num_partition_graphs() const799 size_t NonOwnedProtoRunGraphResponse::num_partition_graphs() const {
800   return response_->partition_graph_size();
801 }
802 
mutable_partition_graph(size_t i)803 GraphDef* NonOwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
804   return response_->mutable_partition_graph(i);
805 }
806 
AddPartitionGraph(const GraphDef & partition_graph)807 void NonOwnedProtoRunGraphResponse::AddPartitionGraph(
808     const GraphDef& partition_graph) {
809   GraphDef* graph_def = response_->add_partition_graph();
810   *graph_def = partition_graph;
811 }
812 
~MutableRunStepResponseWrapper()813 MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {}
814 
num_tensors() const815 size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); }
816 
tensor_name(size_t i) const817 const string& InMemoryRunStepResponse::tensor_name(size_t i) const {
818   return tensors_[i].first;
819 }
820 
TensorValue(size_t i,Tensor * out_tensor) const821 Status InMemoryRunStepResponse::TensorValue(size_t i,
822                                             Tensor* out_tensor) const {
823   *out_tensor = tensors_[i].second;
824   return Status::OK();
825 }
826 
metadata() const827 const RunMetadata& InMemoryRunStepResponse::metadata() const {
828   return metadata_;
829 }
830 
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * wrapper,size_t i)831 Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse(
832     const string& name, MutableRunGraphResponseWrapper* wrapper, size_t i) {
833   Tensor tensor;
834   TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor));
835   tensors_.emplace_back(name, tensor);
836   return Status::OK();
837 }
838 
mutable_metadata()839 RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; }
840 
status_code() const841 errors::Code InMemoryRunStepResponse::status_code() const {
842   return status_.code();
843 }
844 
status_error_message() const845 const string& InMemoryRunStepResponse::status_error_message() const {
846   return status_.error_message();
847 }
848 
set_status(const Status & status)849 void InMemoryRunStepResponse::set_status(const Status& status) {
850   status_ = status;
851 }
852 
get_proto()853 RunStepResponse* InMemoryRunStepResponse::get_proto() {
854   LOG(FATAL) << "Cannot get a mutable protobuf for an InMemoryRunStepResponse";
855   return nullptr;
856 }
857 
num_tensors() const858 size_t OwnedProtoRunStepResponse::num_tensors() const {
859   return response_.tensor_size();
860 }
861 
tensor_name(size_t i) const862 const string& OwnedProtoRunStepResponse::tensor_name(size_t i) const {
863   return response_.tensor(i).name();
864 }
865 
TensorValue(size_t i,Tensor * out_tensor) const866 Status OwnedProtoRunStepResponse::TensorValue(size_t i,
867                                               Tensor* out_tensor) const {
868   if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) {
869     return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
870   } else {
871     return Status::OK();
872   }
873 }
874 
metadata() const875 const RunMetadata& OwnedProtoRunStepResponse::metadata() const {
876   return response_.metadata();
877 }
878 
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * run_graph_response,size_t i)879 Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
880     const string& name, MutableRunGraphResponseWrapper* run_graph_response,
881     size_t i) {
882   NamedTensorProto* response_tensor = response_.add_tensor();
883   response_tensor->set_name(name);
884   return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
885 }
886 
mutable_metadata()887 RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() {
888   return response_.mutable_metadata();
889 }
890 
status_code() const891 errors::Code OwnedProtoRunStepResponse::status_code() const {
892   return response_.status_code();
893 }
894 
status_error_message() const895 const string& OwnedProtoRunStepResponse::status_error_message() const {
896   return response_.status_error_message();
897 }
898 
set_status(const Status & status)899 void OwnedProtoRunStepResponse::set_status(const Status& status) {
900   response_.set_status_code(status.code());
901   response_.set_status_error_message(status.error_message());
902 }
903 
get_proto()904 RunStepResponse* OwnedProtoRunStepResponse::get_proto() { return &response_; }
905 
NonOwnedProtoRunStepResponse(RunStepResponse * response)906 NonOwnedProtoRunStepResponse::NonOwnedProtoRunStepResponse(
907     RunStepResponse* response)
908     : response_(response) {}
909 
num_tensors() const910 size_t NonOwnedProtoRunStepResponse::num_tensors() const {
911   return response_->tensor_size();
912 }
913 
tensor_name(size_t i) const914 const string& NonOwnedProtoRunStepResponse::tensor_name(size_t i) const {
915   return response_->tensor(i).name();
916 }
917 
TensorValue(size_t i,Tensor * out_tensor) const918 Status NonOwnedProtoRunStepResponse::TensorValue(size_t i,
919                                                  Tensor* out_tensor) const {
920   if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) {
921     return errors::InvalidArgument("Invalid TensorProto for fetch value ", i);
922   } else {
923     return Status::OK();
924   }
925 }
926 
metadata() const927 const RunMetadata& NonOwnedProtoRunStepResponse::metadata() const {
928   return response_->metadata();
929 }
930 
AddTensorFromRunGraphResponse(const string & name,MutableRunGraphResponseWrapper * run_graph_response,size_t i)931 Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse(
932     const string& name, MutableRunGraphResponseWrapper* run_graph_response,
933     size_t i) {
934   NamedTensorProto* response_tensor = response_->add_tensor();
935   response_tensor->set_name(name);
936   return run_graph_response->RecvValue(i, response_tensor->mutable_tensor());
937 }
938 
mutable_metadata()939 RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() {
940   return response_->mutable_metadata();
941 }
942 
status_code() const943 errors::Code NonOwnedProtoRunStepResponse::status_code() const {
944   return response_->status_code();
945 }
946 
status_error_message() const947 const string& NonOwnedProtoRunStepResponse::status_error_message() const {
948   return response_->status_error_message();
949 }
950 
set_status(const Status & status)951 void NonOwnedProtoRunStepResponse::set_status(const Status& status) {
952   response_->set_status_code(status.code());
953   response_->set_status_error_message(status.error_message());
954 }
955 
get_proto()956 RunStepResponse* NonOwnedProtoRunStepResponse::get_proto() { return response_; }
957 
958 }  // namespace tensorflow
959