1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
18 
19 #include <functional>
20 
21 #include "tensorflow/core/distributed_runtime/call_options.h"
22 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
23 #include "tensorflow/core/lib/core/notification.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/platform/types.h"
26 #include "tensorflow/core/protobuf/worker.pb.h"
27 
28 namespace tensorflow {
29 
30 // Status callback.
31 typedef std::function<void(const Status&)> StatusCallback;
32 
33 // Custom decoder for a response to RecvTensorAsync.
34 class TensorResponse;
35 
36 // Interface for talking with the TensorFlow Worker service.
37 class WorkerInterface {
38  public:
39   virtual void GetStatusAsync(const GetStatusRequest* request,
40                               GetStatusResponse* response,
41                               StatusCallback done) = 0;
42 
43   virtual void CreateWorkerSessionAsync(
44       const CreateWorkerSessionRequest* request,
45       CreateWorkerSessionResponse* response, StatusCallback done) = 0;
46 
47   virtual void DeleteWorkerSessionAsync(
48       CallOptions* opts, const DeleteWorkerSessionRequest* request,
49       DeleteWorkerSessionResponse* response, StatusCallback done) = 0;
50 
51   virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
52                                   RegisterGraphResponse* response,
53                                   StatusCallback done) = 0;
54 
55   virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request,
56                                     DeregisterGraphResponse* response,
57                                     StatusCallback done) = 0;
58 
59   virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
60                              MutableRunGraphResponseWrapper* repsonse,
61                              StatusCallback done) = 0;
62 
RunGraphAsync(CallOptions * opts,const RunGraphRequest * request,RunGraphResponse * response,StatusCallback done)63   virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
64                              RunGraphResponse* response, StatusCallback done) {
65     // TODO(mrry): Convert this to std::bind/std::move if the overhead
66     // of std::function copying becomes too much.
67     RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request);
68     MutableRunGraphResponseWrapper* wrapped_response =
69         new NonOwnedProtoRunGraphResponse(response);
70     RunGraphAsync(opts, wrapped_request, wrapped_response,
71                   [wrapped_request, wrapped_response, done](const Status& s) {
72                     done(s);
73                     delete wrapped_request;
74                     delete wrapped_response;
75                   });
76   }
77 
78   // Returns a request object for use in calls to
79   // `RunGraphAsync()`. Ownership is transferred to the caller.
80   //
81   // The message returned from this method must only be used in a
82   // `RunGraph()` call on the same `WorkerInterface` instance.
CreateRunGraphRequest()83   virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() {
84     return new MutableProtoRunGraphRequest;
85   }
86 
87   // Returns a response object for use in calls to
88   // `RunGraphAsync()`. Ownership is transferred to the caller.
89   //
90   // The message returned from this method must only be used in a
91   // `RunGraph()` call on the same `WorkerInterface` instance.
CreateRunGraphResponse()92   virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() {
93     return new OwnedProtoRunGraphResponse;
94   }
95 
96   virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
97                                  CleanupGraphResponse* response,
98                                  StatusCallback done) = 0;
99 
100   virtual void CleanupAllAsync(const CleanupAllRequest* request,
101                                CleanupAllResponse* response,
102                                StatusCallback done) = 0;
103 
104   virtual void RecvTensorAsync(CallOptions* opts,
105                                const RecvTensorRequest* request,
106                                TensorResponse* response,
107                                StatusCallback done) = 0;
108 
109   virtual void LoggingAsync(const LoggingRequest* request,
110                             LoggingResponse* response, StatusCallback done) = 0;
111 
112   virtual void TracingAsync(const TracingRequest* request,
113                             TracingResponse* response, StatusCallback done) = 0;
114 
115   virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
116                             RecvBufResponse* response, StatusCallback done) = 0;
117 
118   virtual void CompleteGroupAsync(CallOptions* opts,
119                                   const CompleteGroupRequest* request,
120                                   CompleteGroupResponse* response,
121                                   StatusCallback done) = 0;
122 
123   virtual void CompleteInstanceAsync(CallOptions* ops,
124                                      const CompleteInstanceRequest* request,
125                                      CompleteInstanceResponse* response,
126                                      StatusCallback done) = 0;
127 
128   virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
129                                     GetStepSequenceResponse* response,
130                                     StatusCallback done) = 0;
131 
GetStatus(const GetStatusRequest * request,GetStatusResponse * response)132   Status GetStatus(const GetStatusRequest* request,
133                    GetStatusResponse* response) {
134     return CallAndWait(&ME::GetStatusAsync, request, response);
135   }
136 
CreateWorkerSession(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response)137   Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
138                              CreateWorkerSessionResponse* response) {
139     return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
140   }
141 
DeleteWorkerSession(const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response)142   Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request,
143                              DeleteWorkerSessionResponse* response) {
144     return CallAndWaitWithOptions(&ME::DeleteWorkerSessionAsync, request,
145                                   response);
146   }
147 
RegisterGraph(const RegisterGraphRequest * request,RegisterGraphResponse * response)148   Status RegisterGraph(const RegisterGraphRequest* request,
149                        RegisterGraphResponse* response) {
150     return CallAndWait(&ME::RegisterGraphAsync, request, response);
151   }
152 
DeregisterGraph(const DeregisterGraphRequest * request,DeregisterGraphResponse * response)153   Status DeregisterGraph(const DeregisterGraphRequest* request,
154                          DeregisterGraphResponse* response) {
155     return CallAndWait(&ME::DeregisterGraphAsync, request, response);
156   }
157 
CleanupGraph(const CleanupGraphRequest * request,CleanupGraphResponse * response)158   Status CleanupGraph(const CleanupGraphRequest* request,
159                       CleanupGraphResponse* response) {
160     return CallAndWait(&ME::CleanupGraphAsync, request, response);
161   }
162 
CleanupAll(const CleanupAllRequest * request,CleanupAllResponse * response)163   Status CleanupAll(const CleanupAllRequest* request,
164                     CleanupAllResponse* response) {
165     return CallAndWait(&ME::CleanupAllAsync, request, response);
166   }
167 
Logging(const LoggingRequest * request,LoggingResponse * response)168   Status Logging(const LoggingRequest* request, LoggingResponse* response) {
169     return CallAndWait(&ME::LoggingAsync, request, response);
170   }
171 
Tracing(const TracingRequest * request,TracingResponse * response)172   Status Tracing(const TracingRequest* request, TracingResponse* response) {
173     return CallAndWait(&ME::TracingAsync, request, response);
174   }
175 
GetStepSequence(const GetStepSequenceRequest * request,GetStepSequenceResponse * response)176   Status GetStepSequence(const GetStepSequenceRequest* request,
177                          GetStepSequenceResponse* response) {
178     return CallAndWait(&ME::GetStepSequenceAsync, request, response);
179   }
180 
181  protected:
182   // Instances of WorkerInterface must be deleted by a call to
183   // WorkerCacheInterface::ReleaseWorker().
~WorkerInterface()184   virtual ~WorkerInterface() {}
185   friend class WorkerCacheInterface;
186 
187   // NOTE: This should only be called by implementations of this
188   // interface whose CreateRunGraphResponse() method returns a
189   // proto-based wrappers for the RunGraphResponse message.
get_proto_from_wrapper(MutableRunGraphResponseWrapper * wrapper)190   RunGraphResponse* get_proto_from_wrapper(
191       MutableRunGraphResponseWrapper* wrapper) {
192     return wrapper->get_proto();
193   }
194 
195  private:
196   typedef WorkerInterface ME;
197 
198   template <typename Method, typename Req, typename Resp>
CallAndWait(Method func,const Req * req,Resp * resp)199   Status CallAndWait(Method func, const Req* req, Resp* resp) {
200     Status ret;
201     Notification n;
202     (this->*func)(req, resp, [&ret, &n](const Status& s) {
203       ret = s;
204       n.Notify();
205     });
206     n.WaitForNotification();
207     return ret;
208   }
209 
210   template <typename Method, typename Req, typename Resp>
CallAndWaitWithOptions(Method func,const Req * req,Resp * resp)211   Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) {
212     CallOptions call_opts;
213     Status ret;
214     Notification n;
215     (this->*func)(&call_opts, req, resp, [&ret, &n](const Status& s) {
216       ret = s;
217       n.Notify();
218     });
219     n.WaitForNotification();
220     return ret;
221   }
222 };
223 
224 }  // namespace tensorflow
225 
226 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
227