1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
17 
18 #include <utility>
19 
20 #include "grpcpp/generic/generic_stub.h"
21 #include "grpcpp/grpcpp.h"
22 
23 #include "tensorflow/core/common_runtime/process_util.h"
24 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
27 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
28 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
29 #include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
30 #include "tensorflow/core/distributed_runtime/worker_interface.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/core/threadpool.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/tracing.h"
37 #include "tensorflow/core/protobuf/transport_options.pb.h"
38 #include "tensorflow/core/protobuf/worker.pb.h"
39 
40 namespace tensorflow {
41 
// Default value of the `max_retries` argument of
// GrpcRemoteWorker::IssueRequest(), which forwards it to RPCState.
constexpr int kMaxWorkerRpcRetries = 10;
43 
44 class GrpcRemoteWorker : public WorkerInterface {
45  public:
GrpcRemoteWorker(SharedGrpcChannelPtr channel,::grpc::CompletionQueue * completion_queue,thread::ThreadPool * callback_threadpool,WorkerCacheLogger * logger)46   explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
47                             ::grpc::CompletionQueue* completion_queue,
48                             thread::ThreadPool* callback_threadpool,
49                             WorkerCacheLogger* logger)
50       : channel_(std::move(channel)),
51         stub_(channel_),
52         cq_(completion_queue),
53         callback_threadpool_(callback_threadpool),
54         getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
55         createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
56         deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
57         registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
58         deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
59         rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
60         cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
61         cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
62         recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
63         recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
64         logging_(Method(GrpcWorkerMethod::kLogging)),
65         tracing_(Method(GrpcWorkerMethod::kTracing)),
66         completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
67         instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
68         getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
69         logger_(logger) {}
70 
~GrpcRemoteWorker()71   ~GrpcRemoteWorker() override {}
72 
GetStatusAsync(const GetStatusRequest * request,GetStatusResponse * response,StatusCallback done)73   void GetStatusAsync(const GetStatusRequest* request,
74                       GetStatusResponse* response,
75                       StatusCallback done) override {
76     IssueRequest(request, response, getstatus_, std::move(done));
77   }
78 
CreateWorkerSessionAsync(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response,StatusCallback done)79   void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
80                                 CreateWorkerSessionResponse* response,
81                                 StatusCallback done) override {
82     IssueRequest(request, response, createworkersession_, std::move(done));
83   }
84 
DeleteWorkerSessionAsync(CallOptions * call_opts,const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response,StatusCallback done)85   void DeleteWorkerSessionAsync(CallOptions* call_opts,
86                                 const DeleteWorkerSessionRequest* request,
87                                 DeleteWorkerSessionResponse* response,
88                                 StatusCallback done) override {
89     IssueRequest(request, response, deleteworkersession_, std::move(done),
90                  call_opts);
91   }
92 
RegisterGraphAsync(const RegisterGraphRequest * request,RegisterGraphResponse * response,StatusCallback done)93   void RegisterGraphAsync(const RegisterGraphRequest* request,
94                           RegisterGraphResponse* response,
95                           StatusCallback done) override {
96     IssueRequest(request, response, registergraph_, std::move(done));
97   }
98 
DeregisterGraphAsync(const DeregisterGraphRequest * request,DeregisterGraphResponse * response,StatusCallback done)99   void DeregisterGraphAsync(const DeregisterGraphRequest* request,
100                             DeregisterGraphResponse* response,
101                             StatusCallback done) override {
102     IssueRequest(request, response, deregistergraph_, std::move(done));
103   }
104 
RunGraphAsync(CallOptions * call_opts,const RunGraphRequest * request,RunGraphResponse * response,StatusCallback done)105   void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
106                      RunGraphResponse* response, StatusCallback done) override {
107     IssueRequest(request, response, rungraph_, std::move(done), call_opts);
108   }
RunGraphAsync(CallOptions * call_opts,RunGraphRequestWrapper * request,MutableRunGraphResponseWrapper * response,StatusCallback done)109   void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
110                      MutableRunGraphResponseWrapper* response,
111                      StatusCallback done) override {
112     IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
113                  rungraph_, std::move(done), call_opts);
114   }
115 
CleanupGraphAsync(const CleanupGraphRequest * request,CleanupGraphResponse * response,StatusCallback done)116   void CleanupGraphAsync(const CleanupGraphRequest* request,
117                          CleanupGraphResponse* response,
118                          StatusCallback done) override {
119     IssueRequest(request, response, cleanupgraph_, std::move(done));
120   }
121 
CleanupAllAsync(const CleanupAllRequest * request,CleanupAllResponse * response,StatusCallback done)122   void CleanupAllAsync(const CleanupAllRequest* request,
123                        CleanupAllResponse* response,
124                        StatusCallback done) override {
125     IssueRequest(request, response, cleanupall_, std::move(done));
126   }
127 
RecvBufAsync(CallOptions * call_opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)128   void RecvBufAsync(CallOptions* call_opts, const RecvBufRequest* request,
129                     RecvBufResponse* response, StatusCallback done) override {
130     int64 start_usec = Env::Default()->NowMicros();
131     // Type-specialized logging for this method.
132     bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
133     StatusCallback wrapper_done;
134     const StatusCallback* cb_to_use;
135     if (!logging_active) {
136       cb_to_use = &done;  // No additional work to do, so just use done directly
137     } else {
138       wrapper_done = [this, request, response, done, start_usec](Status s) {
139         if (logger_->LoggingActive()) {
140           int64 end_usec = Env::Default()->NowMicros();
141           int64 step_id = request->step_id();
142           RecvBufRespExtra extra;
143           response->transport_options().UnpackTo(&extra);
144           int64 num_bytes = 0;
145           for (const auto& chunk : extra.tensor_content()) {
146             num_bytes += chunk.size();
147           }
148           int64 send_start_usec = start_usec;
149           // Prefer start time reported by the sender, if available.
150           if (response->send_start_micros()) {
151             send_start_usec = std::max(
152                 start_usec, static_cast<int64>(response->send_start_micros()));
153             send_start_usec = std::min(send_start_usec, end_usec - 1);
154           }
155           const string& key = request->buf_rendezvous_key();
156           logger_->RecordDataTransfer(
157               step_id, send_start_usec, end_usec, key, request->src_device(),
158               request->dst_device(), num_bytes, "", "RecvBuf");
159         }
160         VLOG(2) << "done callback, req: " << request->DebugString()
161                 << " response " << response->DebugString();
162         done(s);
163       };
164       cb_to_use = &wrapper_done;
165     }
166 
167     IssueRequest(request, response, recvbuf_, *cb_to_use, call_opts);
168   }
169 
CompleteGroupAsync(CallOptions * call_opts,const CompleteGroupRequest * request,CompleteGroupResponse * response,StatusCallback done)170   void CompleteGroupAsync(CallOptions* call_opts,
171                           const CompleteGroupRequest* request,
172                           CompleteGroupResponse* response,
173                           StatusCallback done) override {
174     IssueRequest(request, response, completegroup_, std::move(done), call_opts);
175   }
176 
CompleteInstanceAsync(CallOptions * call_opts,const CompleteInstanceRequest * request,CompleteInstanceResponse * response,StatusCallback done)177   void CompleteInstanceAsync(CallOptions* call_opts,
178                              const CompleteInstanceRequest* request,
179                              CompleteInstanceResponse* response,
180                              StatusCallback done) override {
181     IssueRequest(request, response, instancesource_, std::move(done),
182                  call_opts);
183   }
184 
GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,StatusCallback done)185   void GetStepSequenceAsync(const GetStepSequenceRequest* request,
186                             GetStepSequenceResponse* response,
187                             StatusCallback done) override {
188     IssueRequest(request, response, getstepsequence_, std::move(done));
189   }
190 
RecvTensorAsync(CallOptions * call_opts,const RecvTensorRequest * request,TensorResponse * response,StatusCallback done)191   void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request,
192                        TensorResponse* response, StatusCallback done) override {
193     VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
194     int64 start_usec = Env::Default()->NowMicros();
195     // Type-specialized logging for this method.
196     bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
197     StatusCallback wrapper_done;
198     const StatusCallback* cb_to_use;
199     if (!logging_active) {
200       cb_to_use = &done;  // No additional work to do, so just use done directly
201     } else {
202       wrapper_done = [this, request, response, done, start_usec](Status s) {
203         if (logger_->LoggingActive()) {
204           int64 end_usec = Env::Default()->NowMicros();
205           int64 step_id = request->step_id();
206           int64 bytes = response->tensor().TotalBytes();
207           int64 send_start_usec = start_usec;
208           // If a send start time was reported by the other side, use
209           // that instead.  Maybe we should mark the display if we're using
210           // our local time instead of the remote start time?
211           if (response->metadata().send_start_micros()) {
212             // send_start_micros is the timestamp taken when the
213             // remote machine began to send the RecvTensor response.
214             // Due to clock skew between source and dest machines, it
215             // is possible that send_start_micros can be larger than
216             // end_usec or less than start_usec.
217             //
218             // To respect causality, we enforce the invariants that
219             // the RecvTensor response can not have been sent before
220             // the RecvTensor request, and must have been sent before
221             // it was received.
222             send_start_usec = std::max(
223                 start_usec,
224                 static_cast<int64>(response->metadata().send_start_micros()));
225             send_start_usec = std::min(send_start_usec, end_usec - 1);
226           }
227           const string& key = request->rendezvous_key();
228           std::vector<string> key_parts = str_util::Split(key, ';');
229           if (key_parts.size() != 5) {
230             LOG(WARNING) << "Bad key: " << key;
231           } else {
232             logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
233                                       key_parts[3],  // tensor name
234                                       key_parts[0],  // src_device
235                                       key_parts[2],  // dst_device
236                                       bytes);
237           }
238         }
239         VLOG(2) << "done callback, req: " << request->DebugString()
240                 << " response " << response->metadata().DebugString();
241         done(s);
242       };
243       cb_to_use = &wrapper_done;
244     }
245 
246     IssueRequest(request, response, recvtensor_, *cb_to_use, call_opts);
247   }
248 
LoggingAsync(const LoggingRequest * request,LoggingResponse * response,StatusCallback done)249   void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
250                     StatusCallback done) override {
251     IssueRequest(request, response, logging_, done);
252   }
253 
TracingAsync(const TracingRequest * request,TracingResponse * response,StatusCallback done)254   void TracingAsync(const TracingRequest* request, TracingResponse* response,
255                     StatusCallback done) override {
256     IssueRequest(request, response, tracing_, done);
257   }
258 
259  private:
260   // Utility method for issuing a generic asynchronous request. The
261   // given callback, `done`, will be called when the RPC completes.
IssueRequest(const protobuf::Message * request,protobuf::Message * response,const::grpc::string & method,StatusCallback done,CallOptions * call_opts=nullptr,int max_retries=kMaxWorkerRpcRetries)262   void IssueRequest(const protobuf::Message* request,
263                     protobuf::Message* response, const ::grpc::string& method,
264                     StatusCallback done, CallOptions* call_opts = nullptr,
265                     int max_retries = kMaxWorkerRpcRetries) {
266     new RPCState<protobuf::Message>(&stub_, cq_, method, *request, response,
267                                     std::move(done), call_opts,
268                                     callback_threadpool_, max_retries);
269   }
IssueRequest(const protobuf::Message * request,TensorResponse * response,const::grpc::string & method,StatusCallback done,CallOptions * call_opts=nullptr,int max_retries=kMaxWorkerRpcRetries)270   void IssueRequest(const protobuf::Message* request, TensorResponse* response,
271                     const ::grpc::string& method, StatusCallback done,
272                     CallOptions* call_opts = nullptr,
273                     int max_retries = kMaxWorkerRpcRetries) {
274     new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
275                                  std::move(done), call_opts,
276                                  callback_threadpool_, max_retries);
277   }
278 
279   // Helper function for initializing the RpcMethod objects below.
Method(GrpcWorkerMethod id)280   const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }
281 
282   SharedGrpcChannelPtr channel_;
283   ::grpc::GenericStub stub_;
284   ::grpc::CompletionQueue* cq_;
285   thread::ThreadPool* callback_threadpool_;
286 
287   const ::grpc::string getstatus_;
288   const ::grpc::string createworkersession_;
289   const ::grpc::string deleteworkersession_;
290   const ::grpc::string registergraph_;
291   const ::grpc::string deregistergraph_;
292   const ::grpc::string rungraph_;
293   const ::grpc::string cleanupgraph_;
294   const ::grpc::string cleanupall_;
295   const ::grpc::string recvtensor_;
296   const ::grpc::string recvbuf_;
297   const ::grpc::string logging_;
298   const ::grpc::string tracing_;
299   const ::grpc::string completegroup_;
300   const ::grpc::string instancesource_;
301   const ::grpc::string getstepsequence_;
302 
303   // Support for logging.
304   WorkerCacheLogger* logger_;
305 
306   TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
307 };
308 
NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,::grpc::CompletionQueue * completion_queue,thread::ThreadPool * callback_threadpool,WorkerCacheLogger * logger)309 WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
310                                      ::grpc::CompletionQueue* completion_queue,
311                                      thread::ThreadPool* callback_threadpool,
312                                      WorkerCacheLogger* logger) {
313   return new GrpcRemoteWorker(std::move(channel), completion_queue,
314                               callback_threadpool, logger);
315 }
316 
317 }  // namespace tensorflow
318