1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// GrpcMasterService implements the RPC service MasterService.
17 //
18 // A GrpcMasterService maintains the state of live graph computation
19 // sessions, each session orchestrates both local and remote devices
20 // to carry out the graph computation.
21 //
22 // A GrpcMasterService knows ahead of time local devices available as
23 // client devices.
24 //
25 // A GrpcMasterService discovers remote devices in the background and
26 // keeps track of statistics of those remote devices.
27 //
28 // Each session analyzes the graph, places nodes across available
29 // devices, and ultimately drives the graph computation by initiating
30 // RunGraph on workers.
31 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
32 
33 #include "grpcpp/alarm.h"
34 #include "grpcpp/server_builder.h"
35 
36 #include "tensorflow/core/distributed_runtime/master.h"
37 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
40 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/platform/macros.h"
43 #include "tensorflow/core/platform/tracing.h"
44 #include "tensorflow/core/protobuf/master.pb.h"
45 
46 namespace tensorflow {
47 
48 class GrpcMasterService : public AsyncServiceInterface {
49  public:
GrpcMasterService(Master * master,const ConfigProto & default_session_config,::grpc::ServerBuilder * builder)50   GrpcMasterService(Master* master, const ConfigProto& default_session_config,
51                     ::grpc::ServerBuilder* builder)
52       : master_impl_(master),
53         is_shutdown_(false),
54         default_session_config_(default_session_config) {
55     builder->RegisterService(&master_service_);
56     cq_ = builder->AddCompletionQueue();
57   }
58 
~GrpcMasterService()59   ~GrpcMasterService() override { delete shutdown_alarm_; }
60 
Shutdown()61   void Shutdown() override {
62     bool did_shutdown = false;
63     {
64       mutex_lock l(mu_);
65       if (!is_shutdown_) {
66         LOG(INFO) << "Shutting down GrpcMasterService.";
67         is_shutdown_ = true;
68         did_shutdown = true;
69       }
70     }
71     if (did_shutdown) {
72       // NOTE(mrry): This enqueues a special event (with a null tag)
73       // that causes the completion queue to be shut down on the
74       // polling thread.
75       shutdown_alarm_ =
76           new ::grpc::Alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
77     }
78   }
79 
80 // This macro creates a new request for the given RPC method name
81 // (e.g., `ENQUEUE_REQUEST(RunStep);`), and enqueues it on
82 // `this->cq_`.
83 //
84 // This macro is invoked one or more times for each RPC method to
85 // ensure that there are sufficient completion queue entries to
86 // handle incoming requests without blocking.
87 //
88 // The implementation of the request handler for each RPC method
89 // must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
90 // to keep accepting new requests.
91 #define ENQUEUE_REQUEST(method, supports_cancel)                              \
92   do {                                                                        \
93     mutex_lock l(mu_);                                                        \
94     if (!is_shutdown_) {                                                      \
95       Call<GrpcMasterService, grpc::MasterService::AsyncService,              \
96            method##Request, method##Response>::                               \
97           EnqueueRequest(&master_service_, cq_.get(),                         \
98                          &grpc::MasterService::AsyncService::Request##method, \
99                          &GrpcMasterService::method##Handler,                 \
100                          (supports_cancel));                                  \
101     }                                                                         \
102   } while (0)
103 
HandleRPCsLoop()104   void HandleRPCsLoop() override {
105     ENQUEUE_REQUEST(CreateSession, true);
106     ENQUEUE_REQUEST(ExtendSession, false);
107     for (int i = 0; i < 100; ++i) {
108       ENQUEUE_REQUEST(PartialRunSetup, false);
109       ENQUEUE_REQUEST(RunStep, true);
110     }
111     ENQUEUE_REQUEST(CloseSession, false);
112     ENQUEUE_REQUEST(ListDevices, false);
113     ENQUEUE_REQUEST(Reset, false);
114     ENQUEUE_REQUEST(MakeCallable, false);
115     for (int i = 0; i < 100; ++i) {
116       ENQUEUE_REQUEST(RunCallable, true);
117     }
118     ENQUEUE_REQUEST(ReleaseCallable, false);
119 
120     void* tag;
121     bool ok;
122     while (cq_->Next(&tag, &ok)) {
123       UntypedCall<GrpcMasterService>::Tag* callback_tag =
124           static_cast<UntypedCall<GrpcMasterService>::Tag*>(tag);
125       if (callback_tag) {
126         callback_tag->OnCompleted(this, ok);
127       } else {
128         // NOTE(mrry): A null `callback_tag` indicates that this is
129         // the shutdown alarm.
130         cq_->Shutdown();
131       }
132     }
133   }
134 
135  private:
136   Master* master_impl_ = nullptr;  // Not owned.
137   std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
138   grpc::MasterService::AsyncService master_service_;
139 
140   mutex mu_;
141   bool is_shutdown_ GUARDED_BY(mu_);
142   const ConfigProto default_session_config_;
143   ::grpc::Alarm* shutdown_alarm_ = nullptr;
144 
145   template <class RequestMessage, class ResponseMessage>
146   using MasterCall = Call<GrpcMasterService, grpc::MasterService::AsyncService,
147                           RequestMessage, ResponseMessage>;
148 
149   // RPC handler for creating a session.
CreateSessionHandler(MasterCall<CreateSessionRequest,CreateSessionResponse> * call)150   void CreateSessionHandler(
151       MasterCall<CreateSessionRequest, CreateSessionResponse>* call) {
152     CreateSessionRequest* rewritten_req = new CreateSessionRequest;
153     rewritten_req->mutable_config()->MergeFrom(default_session_config_);
154     rewritten_req->MergeFrom(call->request);
155     master_impl_->CreateSession(rewritten_req, &call->response,
156                                 [call, rewritten_req](const Status& status) {
157                                   call->SendResponse(ToGrpcStatus(status));
158                                   delete rewritten_req;
159                                 });
160     ENQUEUE_REQUEST(CreateSession, true);
161   }
162 
163   // RPC handler for extending a session.
ExtendSessionHandler(MasterCall<ExtendSessionRequest,ExtendSessionResponse> * call)164   void ExtendSessionHandler(
165       MasterCall<ExtendSessionRequest, ExtendSessionResponse>* call) {
166     master_impl_->ExtendSession(&call->request, &call->response,
167                                 [call](const Status& status) {
168                                   call->SendResponse(ToGrpcStatus(status));
169                                 });
170     ENQUEUE_REQUEST(ExtendSession, false);
171   }
172 
173   // RPC handler for setting up a partial run call.
PartialRunSetupHandler(MasterCall<PartialRunSetupRequest,PartialRunSetupResponse> * call)174   void PartialRunSetupHandler(
175       MasterCall<PartialRunSetupRequest, PartialRunSetupResponse>* call) {
176     master_impl_->PartialRunSetup(&call->request, &call->response,
177                                   [call](const Status& status) {
178                                     call->SendResponse(ToGrpcStatus(status));
179                                   });
180     ENQUEUE_REQUEST(PartialRunSetup, false);
181   }
182 
183   // RPC handler for running one step in a session.
RunStepHandler(MasterCall<RunStepRequest,RunStepResponse> * call)184   void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
185     auto* trace = TraceRpc("RunStep/Server", call->client_metadata());
186     CallOptions* call_opts = new CallOptions;
187     if (call->request.options().timeout_in_ms() > 0) {
188       call_opts->SetTimeout(call->request.options().timeout_in_ms());
189     } else {
190       call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
191     }
192     RunStepRequestWrapper* wrapped_request =
193         new ProtoRunStepRequest(&call->request);
194     MutableRunStepResponseWrapper* wrapped_response =
195         new NonOwnedProtoRunStepResponse(&call->response);
196     call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
197     master_impl_->RunStep(
198         call_opts, wrapped_request, wrapped_response,
199         [call, call_opts, wrapped_request, wrapped_response,
200          trace](const Status& status) {
201           call->ClearCancelCallback();
202           delete call_opts;
203           delete wrapped_request;
204           delete trace;
205           if (call->request.store_errors_in_response_body() && !status.ok()) {
206             call->response.set_status_code(status.code());
207             call->response.set_status_error_message(status.error_message());
208             call->SendResponse(ToGrpcStatus(Status::OK()));
209           } else {
210             call->SendResponse(ToGrpcStatus(status));
211           }
212         });
213     ENQUEUE_REQUEST(RunStep, true);
214   }
215 
216   // RPC handler for deleting a session.
CloseSessionHandler(MasterCall<CloseSessionRequest,CloseSessionResponse> * call)217   void CloseSessionHandler(
218       MasterCall<CloseSessionRequest, CloseSessionResponse>* call) {
219     master_impl_->CloseSession(&call->request, &call->response,
220                                [call](const Status& status) {
221                                  call->SendResponse(ToGrpcStatus(status));
222                                });
223     ENQUEUE_REQUEST(CloseSession, false);
224   }
225 
226   // RPC handler for listing devices.
ListDevicesHandler(MasterCall<ListDevicesRequest,ListDevicesResponse> * call)227   void ListDevicesHandler(
228       MasterCall<ListDevicesRequest, ListDevicesResponse>* call) {
229     master_impl_->ListDevices(&call->request, &call->response,
230                               [call](const Status& status) {
231                                 call->SendResponse(ToGrpcStatus(status));
232                               });
233     ENQUEUE_REQUEST(ListDevices, false);
234   }
235 
236   // RPC handler for resetting all sessions.
ResetHandler(MasterCall<ResetRequest,ResetResponse> * call)237   void ResetHandler(MasterCall<ResetRequest, ResetResponse>* call) {
238     master_impl_->Reset(&call->request, &call->response,
239                         [call](const Status& status) {
240                           call->SendResponse(ToGrpcStatus(status));
241                         });
242     ENQUEUE_REQUEST(Reset, false);
243   }
244 
245   // RPC handler for making a callable.
MakeCallableHandler(MasterCall<MakeCallableRequest,MakeCallableResponse> * call)246   void MakeCallableHandler(
247       MasterCall<MakeCallableRequest, MakeCallableResponse>* call) {
248     master_impl_->MakeCallable(&call->request, &call->response,
249                                [call](const Status& status) {
250                                  call->SendResponse(ToGrpcStatus(status));
251                                });
252     ENQUEUE_REQUEST(MakeCallable, false);
253   }
254 
255   // RPC handler for running a callable.
RunCallableHandler(MasterCall<RunCallableRequest,RunCallableResponse> * call)256   void RunCallableHandler(
257       MasterCall<RunCallableRequest, RunCallableResponse>* call) {
258     auto* trace = TraceRpc("RunCallable/Server", call->client_metadata());
259     CallOptions* call_opts = new CallOptions;
260     // The timeout may be overridden by a non-zero timeout in the
261     // callable's `RunOptions`; this overriding will happen inside the
262     // `MasterSession` implementation.
263     call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
264     call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
265     master_impl_->RunCallable(call_opts, &call->request, &call->response,
266                               [call, call_opts, trace](const Status& status) {
267                                 call->ClearCancelCallback();
268                                 delete call_opts;
269                                 delete trace;
270                                 call->SendResponse(ToGrpcStatus(status));
271                               });
272     ENQUEUE_REQUEST(RunCallable, false);
273   }
274 
275   // RPC handler for making a callable.
ReleaseCallableHandler(MasterCall<ReleaseCallableRequest,ReleaseCallableResponse> * call)276   void ReleaseCallableHandler(
277       MasterCall<ReleaseCallableRequest, ReleaseCallableResponse>* call) {
278     master_impl_->ReleaseCallable(&call->request, &call->response,
279                                   [call](const Status& status) {
280                                     call->SendResponse(ToGrpcStatus(status));
281                                   });
282     ENQUEUE_REQUEST(ReleaseCallable, false);
283   }
284 
285 #undef ENQUEUE_REQUEST
286 
287   // Start tracing, including the ID attached to the RPC.
TraceRpc(StringPiece name,const std::multimap<::grpc::string_ref,::grpc::string_ref> & metadata)288   tracing::ScopedActivity* TraceRpc(
289       StringPiece name,
290       const std::multimap<::grpc::string_ref, ::grpc::string_ref>& metadata) {
291     StringPiece id;
292     auto it = metadata.find(GrpcIdKey());
293     if (it != metadata.end()) {
294       id = StringPiece(it->second.data(), it->second.size());
295     }
296     return new tracing::ScopedActivity(name, id);
297   }
298 
299   TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterService);
300 };
301 
NewGrpcMasterService(Master * master,const ConfigProto & default_session_config,::grpc::ServerBuilder * builder)302 AsyncServiceInterface* NewGrpcMasterService(
303     Master* master, const ConfigProto& default_session_config,
304     ::grpc::ServerBuilder* builder) {
305   return new GrpcMasterService(master, default_session_config, builder);
306 }
307 
308 }  // end namespace tensorflow
309