1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
18 
19 #include "tensorflow/core/framework/allocator.h"
20 #include "tensorflow/core/framework/cost_graph.pb.h"
21 #include "tensorflow/core/framework/graph.pb.h"
22 #include "tensorflow/core/framework/step_stats.pb.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor.pb_text.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 #include "tensorflow/core/protobuf/master.pb.h"
28 #include "tensorflow/core/protobuf/worker.pb.h"
29 
30 namespace tensorflow {
31 
32 ////////////////////////////////////////////////////////////////////////////////
33 //
34 // Wrapper classes for the `MasterService.RunStep` request message.
35 //
36 // The `RunStepRequest` message can contain potentially large tensor
37 // data as part of its `feed` submessages. Here we provide specialized
38 // wrappers that avoid copying the tensor data wherever possible.
39 //
40 // See `RunStepRequest` in tensorflow/core/protobuf/master.proto for the
41 // protocol buffer definition.
42 //
43 ////////////////////////////////////////////////////////////////////////////////
44 
45 // Abstract interface for an immutable RunStepRequest message.
46 //
47 // This interface is typically used by server-side components in the
48 // TensorFlow master.
49 class RunStepRequestWrapper {
50  public:
~RunStepRequestWrapper()51   virtual ~RunStepRequestWrapper() {}
52 
53   // REQUIRED: session_handle must be returned by a CreateSession call
54   // to the same master service.
55   virtual const string& session_handle() const = 0;
56 
57   // Partial run handle (optional). If specified, this will be a partial run
58   // execution, run up to the specified fetches.
59   virtual const string& partial_run_handle() const = 0;
60 
61   // Tensors to be fed in the step. Each feed is a named tensor.
62   virtual size_t num_feeds() const = 0;
63   virtual const string& feed_name(size_t i) const = 0;
64 
65   // Stores the content of the feed value at index `i` in `tensor`.
66   virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0;
67   virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0;
68 
69   // Fetches. A list of tensor names. The caller expects a tensor to
70   // be returned for each fetch[i] (see RunStepResponse.tensor). The
71   // order of specified fetches does not change the execution order.
72   virtual size_t num_fetches() const = 0;
73   virtual const string& fetch_name(size_t i) const = 0;
74 
75   // Target Nodes. A list of node names. The named nodes will be run
76   // to but their outputs will not be fetched.
77   virtual size_t num_targets() const = 0;
78   virtual const string& target_name(size_t i) const = 0;
79 
80   // Options for the run call.
81   virtual const RunOptions& options() const = 0;
82 
83   // If true then some errors, e.g., execution errors that have long
84   // error messages, may return an OK RunStepResponse with the actual
85   // error saved in the status_code/status_error_message fields of the
86   // response body. This is a workaround since the RPC subsystem may
87   // truncate long metadata messages.
88   virtual bool store_errors_in_response_body() const = 0;
89 
90   virtual int64 request_id() const = 0;
91 
92   // Returns a human-readable representation of this message for debugging.
93   virtual string DebugString() const = 0;
94 
95   // Returns the wrapped data as a protocol buffer message.
96   virtual const RunStepRequest& ToProto() const = 0;
97 };
98 
99 // Abstract interface for a mutable RunStepRequest message.
100 //
101 // See `RunStepRequestWrapper` above for a description of the fields.
102 class MutableRunStepRequestWrapper : public RunStepRequestWrapper {
103  public:
104   virtual void set_session_handle(const string& handle) = 0;
105   virtual void set_partial_run_handle(const string& handle) = 0;
106   virtual void add_feed(const string& name, const Tensor& value) = 0;
107   virtual void add_fetch(const string& name) = 0;
108   virtual void add_target(const string& name) = 0;
109   virtual RunOptions* mutable_options() = 0;
110   virtual void set_store_errors_in_response_body(bool store_errors) = 0;
111 };
112 
113 // Specialized (and mutable) wrapper for RunStep requests between a client and
114 // master in the same address space.
115 class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
116  public:
117   // RunStepRequestWrapper methods.
118   const string& session_handle() const override;
119   const string& partial_run_handle() const override;
120   size_t num_feeds() const override;
121   const string& feed_name(size_t i) const override;
122   Status FeedValue(size_t i, Tensor* out_tensor) const override;
123   Status FeedValue(size_t i, TensorProto* out_tensor) const override;
124   size_t num_fetches() const override;
125   const string& fetch_name(size_t i) const override;
126   size_t num_targets() const override;
127   const string& target_name(size_t i) const override;
128   const RunOptions& options() const override;
129   string DebugString() const override;
130   const RunStepRequest& ToProto() const override;
131   bool store_errors_in_response_body() const override;
132   int64 request_id() const override;
133 
134   // MutableRunStepRequestWrapper methods.
135   void set_session_handle(const string& handle) override;
136   void set_partial_run_handle(const string& handle) override;
137   void add_feed(const string& name, const Tensor& value) override;
138   void add_fetch(const string& name) override;
139   void add_target(const string& name) override;
140   RunOptions* mutable_options() override;
141   void set_store_errors_in_response_body(bool store_errors) override;
142 
143  private:
144   string session_handle_;
145   string partial_run_handle_;
146   gtl::InlinedVector<std::pair<string, Tensor>, 4> feeds_;
147   gtl::InlinedVector<string, 4> fetches_;
148   gtl::InlinedVector<string, 4> targets_;
149   RunOptions options_;
150   bool store_errors_in_response_body_ = false;
151 
152   // Holds a cached and owned representation of the proto
153   // representation of this request, if needed, so that `ToProto()`
154   // can return a const RunStepRequest&.
155   // NOTE(mrry): Although calls to `ToProto()` on this class are
156   // expected to be rare, retaining ownership of the returned message
157   // makes it easier to return a reference from the proto-backed
158   // representations.
159   mutable std::unique_ptr<RunStepRequest> proto_version_;
160 };
161 
162 // Wrapper for mutable RunStep requests that uses a protobuf message.
163 //
164 // This wrapper class should be used for RunStep requests between a
165 // client and master in different address spaces.
166 class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
167  public:
168   // RunStepRequestWrapper methods.
169   const string& session_handle() const override;
170   const string& partial_run_handle() const override;
171   size_t num_feeds() const override;
172   const string& feed_name(size_t i) const override;
173   Status FeedValue(size_t i, Tensor* out_tensor) const override;
174   Status FeedValue(size_t i, TensorProto* out_tensor) const override;
175   size_t num_fetches() const override;
176   const string& fetch_name(size_t i) const override;
177   size_t num_targets() const override;
178   const string& target_name(size_t i) const override;
179   const RunOptions& options() const override;
180   string DebugString() const override;
181   const RunStepRequest& ToProto() const override;
182   bool store_errors_in_response_body() const override;
183   int64 request_id() const override;
184 
185   // MutableRunStepRequestWrapper methods.
186   void set_session_handle(const string& handle) override;
187   void set_partial_run_handle(const string& handle) override;
188   void add_feed(const string& name, const Tensor& value) override;
189   void add_fetch(const string& name) override;
190   void add_target(const string& name) override;
191   RunOptions* mutable_options() override;
192   void set_store_errors_in_response_body(bool store_errors) override;
193 
194  private:
195   RunStepRequest request_;
196   friend class MasterInterface;
197 };
198 
199 // Wrapper for immutable RunStep requests that use a non-owned
200 // protobuf message.
201 //
202 // This interface is typically used by server-side components in the
203 // TensorFlow master, where the incoming message is a (possibly const)
204 // `RunStepRequest*`.
205 class ProtoRunStepRequest : public RunStepRequestWrapper {
206  public:
207   ProtoRunStepRequest(const RunStepRequest* request);
208 
209   // RunStepRequestWrapper methods.
210   const string& session_handle() const override;
211   const string& partial_run_handle() const override;
212   size_t num_feeds() const override;
213   const string& feed_name(size_t i) const override;
214   Status FeedValue(size_t i, Tensor* out_tensor) const override;
215   Status FeedValue(size_t i, TensorProto* out_tensor) const override;
216   size_t num_fetches() const override;
217   const string& fetch_name(size_t i) const override;
218   size_t num_targets() const override;
219   const string& target_name(size_t i) const override;
220   const RunOptions& options() const override;
221   string DebugString() const override;
222   const RunStepRequest& ToProto() const override;
223   bool store_errors_in_response_body() const override;
224   int64 request_id() const override;
225 
226  private:
227   const RunStepRequest* const request_;  // Not owned.
228 };
229 
230 ////////////////////////////////////////////////////////////////////////////////
231 //
232 // Wrapper classes for the `WorkerService.RunGraph` request message.
233 //
234 // The `RunGraphRequest` message can contain potentially large tensor
235 // data as part of its `send` submessages. Here we provide specialized
236 // wrappers that avoid copying the tensor data wherever possible.
237 //
238 // See `RunGraphRequest` in tensorflow/core/protobuf/worker.proto for the
239 // protocol buffer definition.
240 //
241 ////////////////////////////////////////////////////////////////////////////////
242 
243 // Abstract interface for an immutable RunGraphRequest message.
244 //
245 // This interface is typically used by server-side components in the
246 // TensorFlow worker.
247 class RunGraphRequestWrapper {
248  public:
~RunGraphRequestWrapper()249   virtual ~RunGraphRequestWrapper() {}
250 
251   // The session handle used to register the graph. If empty, a single global
252   // namespace is used.
253   virtual const string& session_handle() const = 0;
254 
255   // Set to true if `CreateWorkerSession` was called for `session_handle`.
256   virtual bool create_worker_session_called() const = 0;
257 
258   // REQUIRED: graph_handle must be returned by a RegisterGraph call
259   // to the same WorkerService.
260   virtual const string& graph_handle() const = 0;
261 
262   // A unique ID to distinguish different runs of the same graph.
263   //
264   // The master generates a global unique `step_id` to distinguish
265   // different runs of the graph computation. Subgraphs communicate
266   // (e.g., send/recv ops) with each other using `step_id` to
267   // distinguish tensors generated by different runs.
268   virtual int64 step_id() const = 0;
269 
270   // Options for this step.
271   virtual const ExecutorOpts& exec_opts() const = 0;
272 
273   // Sends the tensors in "send" into the graph before the run.
274   virtual size_t num_sends() const = 0;
275   virtual const string& send_key(size_t i) const = 0;
276   virtual Status SendValue(size_t i, Tensor* out_tensor) const = 0;
277 
278   // Fetches the keys into `RunGraphResponse.recv` after the run.
279   virtual size_t num_recvs() const = 0;
280   virtual const string& recv_key(size_t i) const = 0;
281 
282   // True if the RunGraphRequest is a partial run request.
283   virtual bool is_partial() const = 0;
284 
285   // True if this is the last partial run request in a sequence of requests.
286   virtual bool is_last_partial_run() const = 0;
287 
288   // If true then some errors, e.g., execution errors that have long
289   // error messages, may return an OK RunStepResponse with the actual
290   // error saved in the status_code/status_error_message fields of the
291   // response body. This is a workaround since the RPC subsystem may
292   // truncate long metadata messages.
293   virtual bool store_errors_in_response_body() const = 0;
294 
295   // Returns the wrapped data as a protocol buffer message.
296   virtual const RunGraphRequest& ToProto() const = 0;
297 };
298 
299 // Abstract interface for a mutable RunGraphRequest message.
300 //
301 // See `RunGraphRequestWrapper` above for a description of the fields.
302 class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
303  public:
304   virtual void set_session_handle(const string& handle) = 0;
305   virtual void set_create_worker_session_called(bool called) = 0;
306   virtual void set_graph_handle(const string& handle) = 0;
307   virtual void set_step_id(int64 step_id) = 0;
308   virtual ExecutorOpts* mutable_exec_opts() = 0;
309 
310   // Stores the i^{th} feed value in `run_step_request` in this
311   // request with the given `send_key`.
312   virtual Status AddSendFromRunStepRequest(
313       const RunStepRequestWrapper& run_step_request, size_t i,
314       const string& send_key) = 0;
315   virtual Status AddSendFromRunCallableRequest(
316       const RunCallableRequest& run_callable_request, size_t i,
317       const string& send_key) = 0;
318 
319   virtual void add_recv_key(const string& recv_key) = 0;
320   virtual void set_is_partial(bool is_partial) = 0;
321   virtual void set_is_last_partial_run(bool is_last_partial_run) = 0;
322   virtual void set_store_errors_in_response_body(bool store_errors) = 0;
323 };
324 
325 class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
326  public:
327   // RunGraphRequestWrapper methods.
328   const string& session_handle() const override;
329   const string& graph_handle() const override;
330   bool create_worker_session_called() const override;
331   int64 step_id() const override;
332   const ExecutorOpts& exec_opts() const override;
333   size_t num_sends() const override;
334   const string& send_key(size_t i) const override;
335   Status SendValue(size_t i, Tensor* out_tensor) const override;
336   size_t num_recvs() const override;
337   const string& recv_key(size_t i) const override;
338   bool is_partial() const override;
339   bool is_last_partial_run() const override;
340   const RunGraphRequest& ToProto() const override;
341   bool store_errors_in_response_body() const override;
342 
343   // MutableRunGraphRequestWrapper methods.
344   void set_session_handle(const string& handle) override;
345   void set_create_worker_session_called(bool called) override;
346   void set_graph_handle(const string& handle) override;
347   void set_step_id(int64 step_id) override;
348   ExecutorOpts* mutable_exec_opts() override;
349   Status AddSendFromRunStepRequest(
350       const RunStepRequestWrapper& run_step_request, size_t i,
351       const string& send_key) override;
352   Status AddSendFromRunCallableRequest(
353       const RunCallableRequest& run_callable_request, size_t i,
354       const string& send_key) override;
355   void add_recv_key(const string& recv_key) override;
356   void set_is_partial(bool is_partial) override;
357   void set_is_last_partial_run(bool is_last_partial_run) override;
358   void set_store_errors_in_response_body(bool store_errors) override;
359 
360  private:
361   string session_handle_;
362   bool create_worker_session_called_ = false;
363   string graph_handle_;
364   int64 step_id_;
365   ExecutorOpts exec_opts_;
366   gtl::InlinedVector<std::pair<string, Tensor>, 4> sends_;
367   gtl::InlinedVector<string, 4> recvs_;
368   bool is_partial_ = false;
369   bool is_last_partial_run_ = false;
370   bool store_errors_in_response_body_ = false;
371 
372   // Holds a cached and owned representation of the proto
373   // representation of this request, if needed, so that `ToProto()`
374   // can return a const RunGraphRequest&.
375   // NOTE(mrry): Although calls to `ToProto()` on this class are
376   // expected to be rare, retaining ownership of the returned message
377   // makes it easier to return a reference from the proto-backed
378   // representations.
379   mutable std::unique_ptr<RunGraphRequest> proto_version_;
380 };
381 
382 class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
383  public:
384   // RunGraphRequestWrapper methods.
385   const string& session_handle() const override;
386   bool create_worker_session_called() const override;
387   const string& graph_handle() const override;
388   int64 step_id() const override;
389   const ExecutorOpts& exec_opts() const override;
390   size_t num_sends() const override;
391   const string& send_key(size_t i) const override;
392   Status SendValue(size_t i, Tensor* out_tensor) const override;
393   size_t num_recvs() const override;
394   const string& recv_key(size_t i) const override;
395   bool is_partial() const override;
396   bool is_last_partial_run() const override;
397   bool store_errors_in_response_body() const override;
398   const RunGraphRequest& ToProto() const override;
399 
400   // MutableRunGraphRequestWrapper methods.
401   void set_session_handle(const string& handle) override;
402   void set_create_worker_session_called(bool called) override;
403   void set_graph_handle(const string& handle) override;
404   void set_step_id(int64 step_id) override;
405   ExecutorOpts* mutable_exec_opts() override;
406   Status AddSendFromRunStepRequest(
407       const RunStepRequestWrapper& run_step_request, size_t i,
408       const string& send_key) override;
409   Status AddSendFromRunCallableRequest(
410       const RunCallableRequest& run_callable_request, size_t i,
411       const string& send_key) override;
412   void add_recv_key(const string& recv_key) override;
413   void set_is_partial(bool is_partial) override;
414   void set_is_last_partial_run(bool is_last_partial_run) override;
415   void set_store_errors_in_response_body(bool store_errors) override;
416 
417  private:
418   RunGraphRequest request_;
419 };
420 
421 class ProtoRunGraphRequest : public RunGraphRequestWrapper {
422  public:
423   ProtoRunGraphRequest(const RunGraphRequest* request);
424 
425   // RunGraphRequestWrapper methods.
426   const string& session_handle() const override;
427   bool create_worker_session_called() const override;
428   const string& graph_handle() const override;
429   int64 step_id() const override;
430   const ExecutorOpts& exec_opts() const override;
431   size_t num_sends() const override;
432   const string& send_key(size_t i) const override;
433   Status SendValue(size_t i, Tensor* out_tensor) const override;
434   size_t num_recvs() const override;
435   const string& recv_key(size_t i) const override;
436   bool is_partial() const override;
437   bool is_last_partial_run() const override;
438   bool store_errors_in_response_body() const override;
439   const RunGraphRequest& ToProto() const override;
440 
441  private:
442   const RunGraphRequest* const request_;  // Not owned.
443 };
444 
445 ////////////////////////////////////////////////////////////////////////////////
446 //
447 // Wrapper classes for the `WorkerService.RunGraph` response message.
448 //
449 // The `RunGraphResponse` message can contain potentially large tensor
450 // data as part of its `recv` submessages. Here we provide specialized
451 // wrappers that avoid copying the tensor data wherever possible.
452 //
453 // See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the
454 // protocol buffer definition.
455 //
456 ////////////////////////////////////////////////////////////////////////////////
457 
458 // Abstract interface for a mutable RunGraphResponse message.
459 //
460 // Note that there is no corresponding (immutable)
461 // RunGraphResponseWrapper class, because the RunGraphResponse object
462 // is always used as a mutable pointer.
463 class MutableRunGraphResponseWrapper {
464  public:
~MutableRunGraphResponseWrapper()465   virtual ~MutableRunGraphResponseWrapper() {}
466 
467   // A list of tensors corresponding to those requested by
468   // `RunGraphRequest.recv_key`.
469   virtual size_t num_recvs() const = 0;
470   virtual const string& recv_key(size_t i) const = 0;
471   // NOTE: The following methods may perform a destructive read, for
472   // efficiency.
473   virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0;
474   virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0;
475   virtual void AddRecv(const string& key, const Tensor& value) = 0;
476 
477   // Submessages that store performance statistics about the subgraph
478   // execution, if necessary.
479   virtual StepStats* mutable_step_stats() = 0;
480   virtual CostGraphDef* mutable_cost_graph() = 0;
481   virtual size_t num_partition_graphs() const = 0;
482   virtual GraphDef* mutable_partition_graph(size_t i) = 0;
483   virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0;
484 
485   // Returned status if requested.
486   virtual errors::Code status_code() const = 0;
487   virtual const string& status_error_message() const = 0;
488   virtual void set_status(const Status& status) = 0;
489 
490  protected:
491   // Returns a mutable protobuf message that represents the contents of
492   // this wrapper, for passing to an RPC subsystem that will populate
493   // the message.
494   //
495   // NOTE: Only `WorkerInterface` subclasses may call this method. The
496   // `InMemoryRunGraphResponse` subclass does not implement this
497   // method, and attempts to call it will fail with a fatal
498   // error. However, as long as callers always call
499   // `WorkerInterface::RunGraphAsync()` with a wrapper object returned
500   // from `WorkerInterface::CreateRunGraphResponse()` called on the
501   // *same* WorkerInterface object, this error will never trigger.
502   virtual RunGraphResponse* get_proto() = 0;
503   friend class WorkerInterface;
504 };
505 
506 class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper {
507  public:
508   // MutableRunGraphResponseWrapper methods.
509   size_t num_recvs() const override;
510   const string& recv_key(size_t i) const override;
511   Status RecvValue(size_t i, TensorProto* out_tensor) override;
512   Status RecvValue(size_t i, Tensor* out_tensor) override;
513   void AddRecv(const string& key, const Tensor& value) override;
514   StepStats* mutable_step_stats() override;
515   CostGraphDef* mutable_cost_graph() override;
516   size_t num_partition_graphs() const override;
517   GraphDef* mutable_partition_graph(size_t i) override;
518   void AddPartitionGraph(const GraphDef& partition_graph) override;
519   errors::Code status_code() const override;
520   const string& status_error_message() const override;
521   void set_status(const Status& status) override;
522 
523  protected:
524   // NOTE: This method is not implemented. See
525   // MutableRunGraphResponseWrapper for an explanation.
526   RunGraphResponse* get_proto() override;
527 
528  private:
529   gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_;
530   StepStats step_stats_;
531   CostGraphDef cost_graph_;
532   std::vector<GraphDef> partition_graphs_;
533   // Store the code and message separately so that they can be updated
534   // independently by setters.
535   Status status_;
536 };
537 
538 // Proto-based message wrapper for use on the client side of the RunGraph RPC.
539 class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
540  public:
541   // MutableRunGraphResponseWrapper methods.
542   size_t num_recvs() const override;
543   const string& recv_key(size_t i) const override;
544   Status RecvValue(size_t i, TensorProto* out_tensor) override;
545   Status RecvValue(size_t i, Tensor* out_tensor) override;
546   void AddRecv(const string& key, const Tensor& value) override;
547   StepStats* mutable_step_stats() override;
548   CostGraphDef* mutable_cost_graph() override;
549   size_t num_partition_graphs() const override;
550   GraphDef* mutable_partition_graph(size_t i) override;
551   void AddPartitionGraph(const GraphDef& partition_graph) override;
552   errors::Code status_code() const override;
553   const string& status_error_message() const override;
554   void set_status(const Status& status) override;
555 
556  protected:
557   RunGraphResponse* get_proto() override;
558 
559  private:
560   RunGraphResponse response_;
561 };
562 
563 // Proto-based message wrapper for use on the server side of the RunGraph RPC.
564 class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
565  public:
566   NonOwnedProtoRunGraphResponse(RunGraphResponse* response);
567 
568   // MutableRunGraphResponseWrapper methods.
569   size_t num_recvs() const override;
570   const string& recv_key(size_t i) const override;
571   Status RecvValue(size_t i, TensorProto* out_tensor) override;
572   Status RecvValue(size_t i, Tensor* out_tensor) override;
573   void AddRecv(const string& key, const Tensor& value) override;
574   StepStats* mutable_step_stats() override;
575   CostGraphDef* mutable_cost_graph() override;
576   size_t num_partition_graphs() const override;
577   GraphDef* mutable_partition_graph(size_t i) override;
578   void AddPartitionGraph(const GraphDef& partition_graph) override;
579   errors::Code status_code() const override;
580   const string& status_error_message() const override;
581   void set_status(const Status& status) override;
582 
583  protected:
584   RunGraphResponse* get_proto() override;
585 
586  private:
587   RunGraphResponse* const response_;
588 };
589 
590 ////////////////////////////////////////////////////////////////////////////////
591 //
592 // Wrapper classes for the `MasterService.RunStep` response message.
593 //
594 // The `RunStepResponse` message can contain potentially large tensor
595 // data as part of its `tensor` submessages. Here we provide specialized
596 // wrappers that avoid copying the tensor data wherever possible.
597 //
598 // See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the
599 // protocol buffer definition.
600 //
601 ////////////////////////////////////////////////////////////////////////////////
602 
603 // Abstract interface for a mutable RunStepResponse message.
604 //
605 // Note that there is no corresponding (immutable)
606 // RunStepResponseWrapper class, because the RunStepResponse object is
607 // always used as a mutable pointer.
608 class MutableRunStepResponseWrapper {
609  public:
610   virtual ~MutableRunStepResponseWrapper();
611 
612   // The values of the tensors whose fetching was requested in the
613   // RunStep call.
614   //
615   // NOTE: The order of the returned tensors may or may not match
616   // the fetch order specified in RunStepRequest.
617   virtual size_t num_tensors() const = 0;
618   virtual const string& tensor_name(size_t i) const = 0;
619   virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0;
620 
621   // Stores the i^{th} recv value in `run_graph_response` in this
622   // response with the given `name`.
623   virtual Status AddTensorFromRunGraphResponse(
624       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
625       size_t i) = 0;
626 
627   // Returned metadata if requested in the options.
628   virtual const RunMetadata& metadata() const = 0;
629   virtual RunMetadata* mutable_metadata() = 0;
630 
631   // Returned status if requested.
632   virtual errors::Code status_code() const = 0;
633   virtual const string& status_error_message() const = 0;
634   virtual void set_status(const Status& status) = 0;
635 
636  protected:
637   // Returns a mutable protobuf message that represents the contents of
638   // this wrapper, for passing to an RPC subsystem that will populate
639   // the message.
640   //
641   // NOTE: Only `MasterInterface` subclasses may call this method. The
642   // `InMemoryRunStepResponse` subclass does not implement this
643   // method, and attempts to call it will fail with a fatal
644   // error. However, as long as callers always call
645   // `MasterInterface::RunStep()` with a wrapper object returned
646   // from `MasterInterface::CreateRunStepResponse()` called on the
647   // *same* MasterInterface object, this error will never trigger.
648   virtual RunStepResponse* get_proto() = 0;
649   friend class MasterInterface;
650 };
651 
652 class InMemoryRunStepResponse : public MutableRunStepResponseWrapper {
653  public:
654   // MutableRunStepResponseWrapper methods.
655   size_t num_tensors() const override;
656   const string& tensor_name(size_t i) const override;
657   Status TensorValue(size_t i, Tensor* out_tensor) const override;
658   Status AddTensorFromRunGraphResponse(
659       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
660       size_t i) override;
661   const RunMetadata& metadata() const override;
662   RunMetadata* mutable_metadata() override;
663   errors::Code status_code() const override;
664   const string& status_error_message() const override;
665   void set_status(const Status& status) override;
666 
667  protected:
668   // NOTE: This method is not implemented. See
669   // MutableRunGraphResponseWrapper for an explanation.
670   RunStepResponse* get_proto() override;
671 
672  private:
673   gtl::InlinedVector<std::pair<string, Tensor>, 4> tensors_;
674   RunMetadata metadata_;
675   // Store the code and message separately so that they can be updated
676   // independently by setters.
677   Status status_;
678 };
679 
680 // Proto-based message wrapper for use on the client side of the RunStep RPC.
681 class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
682  public:
683   // MutableRunStepResponseWrapper methods.
684   size_t num_tensors() const override;
685   const string& tensor_name(size_t i) const override;
686   Status TensorValue(size_t i, Tensor* out_tensor) const override;
687   Status AddTensorFromRunGraphResponse(
688       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
689       size_t i) override;
690   const RunMetadata& metadata() const override;
691   RunMetadata* mutable_metadata() override;
692   errors::Code status_code() const override;
693   const string& status_error_message() const override;
694   void set_status(const Status& status) override;
695 
696  protected:
697   RunStepResponse* get_proto() override;
698 
699  private:
700   RunStepResponse response_;
701 };
702 
703 // Proto-based message wrapper for use on the server side of the RunStep RPC.
704 class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
705  public:
706   NonOwnedProtoRunStepResponse(RunStepResponse* response);
707 
708   // MutableRunStepResponseWrapper methods.
709   size_t num_tensors() const override;
710   const string& tensor_name(size_t i) const override;
711   Status TensorValue(size_t i, Tensor* out_tensor) const override;
712   Status AddTensorFromRunGraphResponse(
713       const string& name, MutableRunGraphResponseWrapper* run_graph_response,
714       size_t i) override;
715   const RunMetadata& metadata() const override;
716   RunMetadata* mutable_metadata() override;
717   errors::Code status_code() const override;
718   const string& status_error_message() const override;
719   void set_status(const Status& status) override;
720 
721  protected:
722   RunStepResponse* get_proto() override;
723 
724  private:
725   RunStepResponse* response_;  // Not owned.
726 };
727 
728 }  // namespace tensorflow
729 
730 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
731