1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
18 
19 #include <functional>
20 
21 #include "tensorflow/core/distributed_runtime/call_options.h"
22 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
23 #include "tensorflow/core/lib/core/notification.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/platform/types.h"
26 #include "tensorflow/core/protobuf/worker.pb.h"
27 
28 namespace tensorflow {
29 
30 // Status callback.
31 typedef std::function<void(const Status&)> StatusCallback;
32 
33 // Custom decoder for a response to RecvTensorAsync.
34 class TensorResponse;
35 
36 // Interface for talking with the TensorFlow Worker service.
37 class WorkerInterface {
38  public:
39   virtual void GetStatusAsync(CallOptions* opts,
40                               const GetStatusRequest* request,
41                               GetStatusResponse* response, bool fail_fast,
42                               StatusCallback done) = 0;
43 
44   virtual void CreateWorkerSessionAsync(
45       const CreateWorkerSessionRequest* request,
46       CreateWorkerSessionResponse* response, StatusCallback done) = 0;
47 
48   virtual void DeleteWorkerSessionAsync(
49       CallOptions* opts, const DeleteWorkerSessionRequest* request,
50       DeleteWorkerSessionResponse* response, StatusCallback done) = 0;
51 
52   virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
53                                   RegisterGraphResponse* response,
54                                   StatusCallback done) = 0;
55 
56   virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request,
57                                     DeregisterGraphResponse* response,
58                                     StatusCallback done) = 0;
59 
60   virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
61                              MutableRunGraphResponseWrapper* response,
62                              StatusCallback done) = 0;
63 
RunGraphAsync(CallOptions * opts,const RunGraphRequest * request,RunGraphResponse * response,StatusCallback done)64   virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
65                              RunGraphResponse* response, StatusCallback done) {
66     RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request);
67     MutableRunGraphResponseWrapper* wrapped_response =
68         new NonOwnedProtoRunGraphResponse(response);
69     RunGraphAsync(opts, wrapped_request, wrapped_response,
70                   [wrapped_request, wrapped_response,
71                    done = std::move(done)](const Status& s) {
72                     done(s);
73                     delete wrapped_request;
74                     delete wrapped_response;
75                   });
76   }
77 
78   // Returns a request object for use in calls to
79   // `RunGraphAsync()`. Ownership is transferred to the caller.
80   //
81   // The message returned from this method must only be used in a
82   // `RunGraph()` call on the same `WorkerInterface` instance.
CreateRunGraphRequest()83   virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() {
84     return new MutableProtoRunGraphRequest;
85   }
86 
87   // Returns a response object for use in calls to
88   // `RunGraphAsync()`. Ownership is transferred to the caller.
89   //
90   // The message returned from this method must only be used in a
91   // `RunGraph()` call on the same `WorkerInterface` instance.
CreateRunGraphResponse()92   virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() {
93     return new OwnedProtoRunGraphResponse;
94   }
95 
96   virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
97                                  CleanupGraphResponse* response,
98                                  StatusCallback done) = 0;
99 
100   virtual void CleanupAllAsync(const CleanupAllRequest* request,
101                                CleanupAllResponse* response,
102                                StatusCallback done) = 0;
103 
104   virtual void RecvTensorAsync(CallOptions* opts,
105                                const RecvTensorRequest* request,
106                                TensorResponse* response,
107                                StatusCallback done) = 0;
108 
109   virtual void LoggingAsync(const LoggingRequest* request,
110                             LoggingResponse* response, StatusCallback done) = 0;
111 
112   virtual void TracingAsync(const TracingRequest* request,
113                             TracingResponse* response, StatusCallback done) = 0;
114 
115   virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
116                             RecvBufResponse* response, StatusCallback done) = 0;
117 
118   virtual void CompleteGroupAsync(CallOptions* opts,
119                                   const CompleteGroupRequest* request,
120                                   CompleteGroupResponse* response,
121                                   StatusCallback done) = 0;
122 
123   virtual void CompleteInstanceAsync(CallOptions* ops,
124                                      const CompleteInstanceRequest* request,
125                                      CompleteInstanceResponse* response,
126                                      StatusCallback done) = 0;
127 
128   virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
129                                     GetStepSequenceResponse* response,
130                                     StatusCallback done) = 0;
131 
GetStatus(const GetStatusRequest * request,GetStatusResponse * response)132   Status GetStatus(const GetStatusRequest* request,
133                    GetStatusResponse* response) {
134     Status ret;
135     Notification n;
136     GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true,
137                    [&ret, &n](const Status& s) {
138                      ret = s;
139                      n.Notify();
140                    });
141     n.WaitForNotification();
142     return ret;
143   }
144 
CreateWorkerSession(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response)145   Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
146                              CreateWorkerSessionResponse* response) {
147     return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
148   }
149 
DeleteWorkerSession(const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response)150   Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request,
151                              DeleteWorkerSessionResponse* response) {
152     return CallAndWaitWithOptions(&ME::DeleteWorkerSessionAsync, request,
153                                   response);
154   }
155 
RegisterGraph(const RegisterGraphRequest * request,RegisterGraphResponse * response)156   Status RegisterGraph(const RegisterGraphRequest* request,
157                        RegisterGraphResponse* response) {
158     return CallAndWait(&ME::RegisterGraphAsync, request, response);
159   }
160 
DeregisterGraph(const DeregisterGraphRequest * request,DeregisterGraphResponse * response)161   Status DeregisterGraph(const DeregisterGraphRequest* request,
162                          DeregisterGraphResponse* response) {
163     return CallAndWait(&ME::DeregisterGraphAsync, request, response);
164   }
165 
CleanupGraph(const CleanupGraphRequest * request,CleanupGraphResponse * response)166   Status CleanupGraph(const CleanupGraphRequest* request,
167                       CleanupGraphResponse* response) {
168     return CallAndWait(&ME::CleanupGraphAsync, request, response);
169   }
170 
CleanupAll(const CleanupAllRequest * request,CleanupAllResponse * response)171   Status CleanupAll(const CleanupAllRequest* request,
172                     CleanupAllResponse* response) {
173     return CallAndWait(&ME::CleanupAllAsync, request, response);
174   }
175 
Logging(const LoggingRequest * request,LoggingResponse * response)176   Status Logging(const LoggingRequest* request, LoggingResponse* response) {
177     return CallAndWait(&ME::LoggingAsync, request, response);
178   }
179 
Tracing(const TracingRequest * request,TracingResponse * response)180   Status Tracing(const TracingRequest* request, TracingResponse* response) {
181     return CallAndWait(&ME::TracingAsync, request, response);
182   }
183 
GetStepSequence(const GetStepSequenceRequest * request,GetStepSequenceResponse * response)184   Status GetStepSequence(const GetStepSequenceRequest* request,
185                          GetStepSequenceResponse* response) {
186     return CallAndWait(&ME::GetStepSequenceAsync, request, response);
187   }
188 
189  protected:
190   // Instances of WorkerInterface must be deleted by a call to
191   // WorkerCacheInterface::ReleaseWorker().
~WorkerInterface()192   virtual ~WorkerInterface() {}
193   friend class WorkerCacheInterface;
194 
195   // NOTE: This should only be called by implementations of this
196   // interface whose CreateRunGraphResponse() method returns a
197   // proto-based wrappers for the RunGraphResponse message.
get_proto_from_wrapper(MutableRunGraphResponseWrapper * wrapper)198   RunGraphResponse* get_proto_from_wrapper(
199       MutableRunGraphResponseWrapper* wrapper) {
200     return wrapper->get_proto();
201   }
202 
203  private:
204   typedef WorkerInterface ME;
205 
206   template <typename Method, typename Req, typename Resp>
CallAndWait(Method func,const Req * req,Resp * resp)207   Status CallAndWait(Method func, const Req* req, Resp* resp) {
208     Status ret;
209     Notification n;
210     (this->*func)(req, resp, [&ret, &n](const Status& s) {
211       ret = s;
212       n.Notify();
213     });
214     n.WaitForNotification();
215     return ret;
216   }
217 
218   template <typename Method, typename Req, typename Resp>
CallAndWaitWithOptions(Method func,const Req * req,Resp * resp)219   Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) {
220     CallOptions call_opts;
221     Status ret;
222     Notification n;
223     (this->*func)(&call_opts, req, resp, [&ret, &n](const Status& s) {
224       ret = s;
225       n.Notify();
226     });
227     n.WaitForNotification();
228     return ret;
229   }
230 };
231 
232 }  // namespace tensorflow
233 
234 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
235