1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
17 
18 #include <deque>
19 #include <memory>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "grpcpp/alarm.h"
24 #include "grpcpp/server_builder.h"
25 
26 #include "absl/container/flat_hash_map.h"
27 #include "tensorflow/core/common_runtime/buf_rendezvous.h"
28 #include "tensorflow/core/common_runtime/device.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/dma_helper.h"
31 #include "tensorflow/core/common_runtime/local_device.h"
32 #include "tensorflow/core/common_runtime/process_util.h"
33 #include "tensorflow/core/common_runtime/step_stats_collector.h"
34 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
35 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
36 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
37 #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
40 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
41 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
42 #include "tensorflow/core/distributed_runtime/worker.h"
43 #include "tensorflow/core/distributed_runtime/worker_cache.h"
44 #include "tensorflow/core/distributed_runtime/worker_session.h"
45 #include "tensorflow/core/framework/cancellation.h"
46 #include "tensorflow/core/framework/collective.h"
47 #include "tensorflow/core/framework/tensor.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/core/status.h"
50 #include "tensorflow/core/lib/gtl/map_util.h"
51 #include "tensorflow/core/lib/strings/strcat.h"
52 #include "tensorflow/core/lib/strings/stringprintf.h"
53 #include "tensorflow/core/platform/logging.h"
54 #include "tensorflow/core/platform/mutex.h"
55 #include "tensorflow/core/platform/tracing.h"
56 #include "tensorflow/core/protobuf/transport_options.pb.h"
57 #include "tensorflow/core/protobuf/worker.pb.h"
58 
59 namespace tensorflow {
60 
61 namespace {
62 
// ENQUEUE_REQUEST(method, supports_cancel) posts one fresh request slot for
// the RPC named `method` on `this->cq_`
// (e.g., `ENQUEUE_REQUEST(GetStatus, false);`).
//
// It is expanded one or more times per RPC method so that the completion
// queue always holds enough entries to accept incoming requests without
// blocking.
//
// Every per-method request handler must expand ENQUEUE_REQUEST() for its own
// method, otherwise that method stops accepting new requests.
//
// Nothing is enqueued once `is_shutdown_` has been set; the flag is read
// while holding `shutdown_mu_`.
#define ENQUEUE_REQUEST(method, supports_cancel)                              \
  do {                                                                        \
    using method##CallType =                                                  \
        Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,      \
             method##Request, method##Response>;                              \
    mutex_lock shutdown_lock(shutdown_mu_);                                   \
    if (is_shutdown_) break;                                                  \
    method##CallType::EnqueueRequestForMethod(                                \
        worker_service_, cq_.get(),                                           \
        static_cast<int>(GrpcWorkerMethod::k##method),                        \
        &GrpcWorkerServiceThread::method##Handler, (supports_cancel));        \
  } while (0)
86 
// SETUP_FOR_REQUEST(method, default_depth, supports_cancel) enqueues the
// initial batch of request slots for `method`. The batch size is the entry
// for that method in `queue_depth_`, or `default_depth` when there is none.
// Must be followed by a semicolon at the expansion site.
#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel)             \
  do {                                                                        \
    const int depth_for_method = gtl::FindWithDefault(                        \
        queue_depth_, static_cast<int>(GrpcWorkerMethod::k##method),          \
        default_depth);                                                       \
    for (int slot = 0; slot < depth_for_method; ++slot) {                     \
      ENQUEUE_REQUEST(method, supports_cancel);                               \
    }                                                                         \
  } while (0)
95 
96 // GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service
97 // requests.  Each thread operates on an independent completion queue.
98 class GrpcWorkerServiceThread {
99  public:
GrpcWorkerServiceThread(GrpcWorker * worker,::grpc::ServerBuilder * builder,std::unordered_map<int,int> queue_depth,GrpcResponseCache * cache,grpc::WorkerService::AsyncService * worker_service)100   explicit GrpcWorkerServiceThread(
101       GrpcWorker* worker, ::grpc::ServerBuilder* builder,
102       std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
103       grpc::WorkerService::AsyncService* worker_service)
104       : worker_(worker),
105         queue_depth_(queue_depth),
106         cache_(cache),
107         worker_service_(worker_service),
108         is_shutdown_(false) {
109     cq_ = builder->AddCompletionQueue();
110   }
111 
Start()112   void Start() {
113     thread_.reset(
114         worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
115                                          [this]() { HandleRPCsLoop(); }));
116   }
117 
Join()118   void Join() { thread_.reset(); }  // Blocks until thread exits
119 
Shutdown()120   void Shutdown() {
121     {
122       mutex_lock lock(shutdown_mu_);
123       is_shutdown_ = true;
124     }
125     cq_->Shutdown();
126   }
127 
128  private:
129   // Add one or more completion queue entries for each worker method, then
130   // begin servicing requests from the completion queue.
HandleRPCsLoop()131   void HandleRPCsLoop() {
132     // TODO(ncteisen): This may require performance engineering. We can
133     // change the number of threads, the number of handlers per thread,
134     // or even decide to specialize certain threads to certain methods.
135     SETUP_FOR_REQUEST(GetStatus, 1, false);
136     SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
137     SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
138     SETUP_FOR_REQUEST(CleanupAll, 1, false);
139     SETUP_FOR_REQUEST(RegisterGraph, 1, false);
140     SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
141     SETUP_FOR_REQUEST(Logging, 1, false);
142     SETUP_FOR_REQUEST(Tracing, 1, false);
143     SETUP_FOR_REQUEST(CompleteGroup, 10, true);
144     SETUP_FOR_REQUEST(CompleteInstance, 10, true);
145     SETUP_FOR_REQUEST(GetStepSequence, 10, true);
146     SETUP_FOR_REQUEST(RecvBuf, 500, true);
147     SETUP_FOR_REQUEST(RunGraph, 100, true);
148     SETUP_FOR_REQUEST(CleanupGraph, 100, false);
149 
150     // TODO(ncteisen): Determine a better policy for enqueuing the
151     // appropriate number of each request type.
152     for (int i = 0;
153          i < gtl::FindWithDefault(
154                  queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
155                  1000);
156          ++i) {
157       EnqueueRecvTensorRequestRaw();
158     }
159 
160     void* tag;
161     bool ok;
162 
163     while (cq_->Next(&tag, &ok)) {
164       UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
165           static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
166       CHECK(callback_tag);
167       callback_tag->OnCompleted(this, ok);
168     }
169   }
170 
171  private:
Schedule(std::function<void ()> f)172   void Schedule(std::function<void()> f) {
173     worker_->env()->compute_pool->Schedule(std::move(f));
174   }
175 
176   // The following section contains one request handler method per
177   // RPC. The `FooHandler` method is called (indirectly) by
178   // `HandleRPCsLoop()` when the next Foo RPC is received. Each
179   // `FooHandler` call schedules a closure on `worker_->env()->compute_pool`,
180   // and is responsible for requesting the next Foo call by calling
181   // `ENQUEUE_REQUEST(Foo)`.
182   template <class RequestMessage, class ResponseMessage>
183   using WorkerCall =
184       Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
185            RequestMessage, ResponseMessage>;
186 
187   // Handle all non-cancellable simple methods with a standard wrapper.
188 #define HANDLE_CALL(method)                                                   \
189   void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
190     Schedule([this, call]() {                                                 \
191       Status s = worker_->method(&call->request, &call->response);            \
192       if (!s.ok()) {                                                          \
193         VLOG(1) << "Bad response from " << #method << ": " << s;              \
194       }                                                                       \
195       call->SendResponse(ToGrpcStatus(s));                                    \
196     });                                                                       \
197     ENQUEUE_REQUEST(method, false);                                           \
198   }
199 
200   HANDLE_CALL(GetStatus);
201   HANDLE_CALL(CreateWorkerSession);
202   HANDLE_CALL(DeleteWorkerSession);
203   HANDLE_CALL(CleanupAll);
204   HANDLE_CALL(RegisterGraph);
205   HANDLE_CALL(DeregisterGraph);
206   HANDLE_CALL(CleanupGraph);
207   HANDLE_CALL(Logging);
208   HANDLE_CALL(Tracing);
209 
210 #undef HANDLE_CALL
211 
GetStepSequenceHandler(WorkerCall<GetStepSequenceRequest,GetStepSequenceResponse> * call)212   void GetStepSequenceHandler(
213       WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
214     Schedule([this, call]() {
215       worker_->GetStepSequenceAsync(
216           &call->request, &call->response, [call](const Status& s) {
217             VLOG(1) << "Bad response from GetStepSequence:" << s;
218             call->SendResponse(ToGrpcStatus(s));
219           });
220     });
221     ENQUEUE_REQUEST(GetStepSequence, true);
222   }
223 
RunGraphHandler(WorkerCall<RunGraphRequest,RunGraphResponse> * call)224   void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
225     Schedule([this, call]() {
226       CallOptions* call_opts = new CallOptions;
227       ProtoRunGraphRequest* wrapped_request =
228           new ProtoRunGraphRequest(&call->request);
229       NonOwnedProtoRunGraphResponse* wrapped_response =
230           new NonOwnedProtoRunGraphResponse(&call->response);
231       call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
232       auto done_cb = [call, call_opts, wrapped_request,
233                       wrapped_response](const Status& s) {
234         VLOG(1) << "RunGraph::Done";
235         if (!s.ok()) {
236           VLOG(1) << "Bad response from RunGraph:" << s;
237         }
238         call->ClearCancelCallback();
239         delete call_opts;
240         delete wrapped_request;
241         delete wrapped_response;
242         call->SendResponse(ToGrpcStatus(s));
243       };
244 
245       auto compute_fn = [this, call_opts, wrapped_request,
246                          wrapped_response](StatusCallback done) {
247         worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
248                                done);
249       };
250 
251       if (cache_) {
252         string request_key = call->request.ShortDebugString();
253         cache_->LookupOrCompute(request_key, RPCResponse(&call->response),
254                                 compute_fn, done_cb);
255       } else {
256         compute_fn(done_cb);
257       }
258     });
259     ENQUEUE_REQUEST(RunGraph, true);
260   }
261 
RecvTensorHandlerRaw(WorkerCall<RecvTensorRequest,::grpc::ByteBuffer> * call)262   void RecvTensorHandlerRaw(
263       WorkerCall<RecvTensorRequest, ::grpc::ByteBuffer>* call) {
264     Schedule([this, call]() {
265       CallOptions* call_opts = new CallOptions;
266       call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
267 
268       auto done_cb = [call, call_opts](const Status& s) {
269         call->ClearCancelCallback();
270         delete call_opts;
271         if (!s.ok()) {
272           VLOG(1) << "Bad response from RecvTensor:" << s;
273         }
274         call->SendResponse(ToGrpcStatus(s));
275       };
276 
277       auto compute_fn = [this, &call_opts, &call](StatusCallback done) {
278         worker_->GrpcRecvTensorAsync(call_opts, &call->request, &call->response,
279                                      done);
280       };
281 
282       if (cache_) {
283         string request_key = call->request.ShortDebugString();
284         cache_->LookupOrCompute(request_key, RPCResponse(&call->response),
285                                 compute_fn, done_cb);
286       } else {
287         compute_fn(done_cb);
288       }
289     });
290     EnqueueRecvTensorRequestRaw();
291   }
292 
RecvBufHandler(WorkerCall<RecvBufRequest,RecvBufResponse> * call)293   void RecvBufHandler(WorkerCall<RecvBufRequest, RecvBufResponse>* call) {
294     Schedule([this, call]() {
295       CallOptions* call_opts = new CallOptions;
296       call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
297       worker_->RecvBufAsync(call_opts, &call->request, &call->response,
298                             [call, call_opts](const Status& s) {
299                               call->ClearCancelCallback();
300                               delete call_opts;
301                               if (!s.ok()) {
302                                 VLOG(1) << "Bad response from RecvBuf:" << s;
303                               }
304                               call->SendResponse(ToGrpcStatus(s));
305                             });
306     });
307     ENQUEUE_REQUEST(RecvBuf, true);
308   }
309 
CompleteGroupHandler(WorkerCall<CompleteGroupRequest,CompleteGroupResponse> * call)310   void CompleteGroupHandler(
311       WorkerCall<CompleteGroupRequest, CompleteGroupResponse>* call) {
312     Schedule([this, call]() {
313       CallOptions* call_opts = new CallOptions;
314       call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
315       worker_->CompleteGroupAsync(
316           call_opts, &call->request, &call->response,
317           [call, call_opts](const Status& s) {
318             call->ClearCancelCallback();
319             delete call_opts;
320             if (!s.ok()) {
321               VLOG(1) << "Bad response from CompleteGroup:" << s;
322             }
323             call->SendResponse(ToGrpcStatus(s));
324           });
325     });
326     ENQUEUE_REQUEST(CompleteGroup, true);
327   }
328 
CompleteInstanceHandler(WorkerCall<CompleteInstanceRequest,CompleteInstanceResponse> * call)329   void CompleteInstanceHandler(
330       WorkerCall<CompleteInstanceRequest, CompleteInstanceResponse>* call) {
331     Schedule([this, call]() {
332       CallOptions* call_opts = new CallOptions;
333       call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
334       worker_->CompleteInstanceAsync(
335           call_opts, &call->request, &call->response,
336           [call, call_opts](const Status& s) {
337             call->ClearCancelCallback();
338             delete call_opts;
339             if (!s.ok()) {
340               VLOG(1) << "Bad response from CompleteInstance:" << s;
341             }
342             call->SendResponse(ToGrpcStatus(s));
343           });
344     });
345     ENQUEUE_REQUEST(CompleteInstance, false);
346   }
347 #undef ENQUEUE_REQUEST
348 
EnqueueRecvTensorRequestRaw()349   void EnqueueRecvTensorRequestRaw() {
350     mutex_lock l(shutdown_mu_);
351     if (!is_shutdown_) {
352       Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
353            RecvTensorRequest, ::grpc::ByteBuffer>::
354           EnqueueRequestForMethod(
355               worker_service_, cq_.get(),
356               static_cast<int>(GrpcWorkerMethod::kRecvTensor),
357               &GrpcWorkerServiceThread::RecvTensorHandlerRaw,
358               true /* supports cancel*/);
359     }
360   }
361 
362   GrpcWorker* const worker_ = nullptr;  // Not owned.
363   std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
364   std::unique_ptr<Thread> thread_;
365   std::unordered_map<int, int> queue_depth_;
366   GrpcResponseCache* cache_;
367   grpc::WorkerService::AsyncService* const worker_service_;
368 
369   mutex shutdown_mu_;
370   bool is_shutdown_ GUARDED_BY(shutdown_mu_);
371   TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerServiceThread);
372 };
373 
374 class GrpcWorkerService : public AsyncServiceInterface {
375  public:
GrpcWorkerService(GrpcWorker * worker,::grpc::ServerBuilder * builder,GrpcWorkerServiceOptions options)376   GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
377                     GrpcWorkerServiceOptions options)
378       : is_shutdown_(false) {
379     builder->RegisterService(&worker_service_);
380     if (options.response_cache_bytes > 0) {
381       cache_.reset(
382           new GrpcResponseCache(options.response_cache_bytes,
383                                 options.response_cache_expires_seconds));
384     }
385 
386     for (int i = 0; i < options.num_serving_threads; i++) {
387       threads_.emplace_back(
388           new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
389                                       cache_.get(), &worker_service_));
390     }
391   }
392 
Shutdown()393   void Shutdown() override {
394     bool did_shutdown = false;
395     {
396       mutex_lock l(service_shutdown_mu_);
397       if (!is_shutdown_) {
398         LOG(INFO) << "Shutting down GrpcWorkerService.";
399         is_shutdown_ = true;
400         did_shutdown = true;
401       }
402     }
403     if (did_shutdown) {
404       for (auto& worker_thread : threads_) {
405         worker_thread->Shutdown();
406       }
407     }
408   }
409 
410   // This method blocks forever handling requests from the completion queue.
HandleRPCsLoop()411   void HandleRPCsLoop() override {
412     for (auto& worker_thread : threads_) {
413       worker_thread->Start();
414     }
415     for (auto& worker_thread : threads_) {
416       worker_thread->Join();
417     }
418   }
419 
420  private:
421   grpc::WorkerService::AsyncService worker_service_;
422   std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;
423 
424   std::unique_ptr<GrpcResponseCache> cache_;
425   mutex service_shutdown_mu_;
426   bool is_shutdown_ GUARDED_BY(service_shutdown_mu_);
427 
428   TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
429 };
430 
431 }  // namespace
432 
GrpcWorker(WorkerEnv * worker_env,const ConfigProto & config)433 GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config)
434     : Worker(worker_env),
435       recent_request_ids_(100000),
436       recv_buf_max_chunk_(
437           config.experimental().recv_buf_max_chunk() > 0
438               ? config.experimental().recv_buf_max_chunk()
439               : (config.experimental().recv_buf_max_chunk() < 0 ? 0 : 4096)) {}
440 
441 // GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
442 // buffers for a response object, to avoid extra protocol buffer serialization
443 // overhead we generate our response directly into a ::grpc::ByteBuffer object
GrpcRecvTensorAsync(CallOptions * opts,const RecvTensorRequest * request,::grpc::ByteBuffer * response,StatusCallback done)444 void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
445                                      const RecvTensorRequest* request,
446                                      ::grpc::ByteBuffer* response,
447                                      StatusCallback done) {
448   Status s = recent_request_ids_.TrackUnique(
449       request->request_id(), "RecvTensor (GrpcWorker)", *request);
450   if (!s.ok()) {
451     done(s);
452     return;
453   }
454 
455   const int64 step_id = request->step_id();
456   const string& key = request->rendezvous_key();
457   TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
458   Rendezvous::ParsedKey parsed;
459   s = Rendezvous::ParseKey(key, &parsed);
460   Device* src_dev = nullptr;
461   if (s.ok()) {
462     s = PrepareRecvTensor(parsed, &src_dev);
463   }
464   if (!s.ok()) {
465     done(s);
466     return;
467   }
468 
469   // Request the tensor associated with the rendezvous key.
470   // Note that we log the cancellation here but do not abort the current step.
471   // gRPC can generate cancellations in response to transient network failures,
472   // and aborting the step eliminates the opportunity for client side retries.
473   // Repeated client failures will eventually cause the step to be aborted by
474   // the client.
475   opts->SetCancelCallback(
476       [step_id]() { LOG(WARNING) << "RecvTensor cancelled for " << step_id; });
477   env_->rendezvous_mgr->RecvLocalAsync(
478       step_id, parsed,
479       [opts, response, done, src_dev, request](
480           const Status& status, const Rendezvous::Args& send_args,
481           const Rendezvous::Args& recv_args, const Tensor& val,
482           const bool is_dead) {
483         opts->ClearCancelCallback();
484         if (status.ok()) {
485           // DMA can only be used for Tensors that do not fall into
486           // the following three odd edge cases: 1) a zero-size
487           // buffer, 2) a dead tensor which has an uninit value, and
488           // 3) the tensor has the on_host allocation attribute,
489           // i.e. it's in CPU RAM *independent of its assigned
490           // device type*.
491           const bool on_host = send_args.alloc_attrs.on_host();
492           {
493             // Non-DMA cases.
494             if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
495               DeviceContext* send_dev_context = send_args.device_context;
496               AllocatorAttributes alloc_attrs;
497               alloc_attrs.set_gpu_compatible(true);
498               alloc_attrs.set_on_host(true);
499               Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
500               Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
501               CHECK(send_dev_context)
502                   << "send dev name: " << src_dev->name()
503                   << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
504               // "val" is on an accelerator device. Uses the device_context to
505               // fill the copy on host.
506               StatusCallback copy_ready = [response, done, copy,
507                                            is_dead](const Status& s) {
508                 // The value is now ready to be returned on the wire.
509                 grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
510                 done(s);
511                 delete copy;
512               };
513 
514               send_dev_context->CopyDeviceTensorToCPU(
515                   &val, request->rendezvous_key(), src_dev, copy, copy_ready);
516             } else {
517               grpc::EncodeTensorToByteBuffer(is_dead, val, response);
518               done(Status::OK());
519             }
520           }
521         } else {
522           //  !s.ok()
523           done(status);
524         }
525       });
526 }
527 
528 namespace {
529 // If RecvBufRespExtra.tensor_content is a single large string, then gRPC
530 // can stall on the recv side when the string buffer needs to be enlarged,
531 // since the size is not sent in advance.  Changing this field to a sequence
532 // of small strings costs some extra time on the send side, since we do
533 // some otherwise unnecessary copies, but it improves runtime overall by
534 // improving flow control.  Best performance is likely achieved with a
535 // max_chunk_bytes equal to the memory page size.
536 //
537 // TODO(tucker): When proto3 supports [ctype=CORD] then change
538 // RecvBufRespExtra.tensor_content to a cord instead of a repeated string,
539 // and remove this function.
SetTensorInRecvBufResp(int64 max_chunk_bytes,const Tensor * tensor,int64 num_bytes,RecvBufResponse * response)540 void SetTensorInRecvBufResp(int64 max_chunk_bytes, const Tensor* tensor,
541                             int64 num_bytes, RecvBufResponse* response) {
542   RecvBufRespExtra extra;
543   const char* head = reinterpret_cast<const char*>(DMAHelper::base(tensor));
544   while (num_bytes > 0) {
545     int64 bytes =
546         max_chunk_bytes > 0 ? std::min(num_bytes, max_chunk_bytes) : num_bytes;
547     extra.add_tensor_content(std::string(head, bytes));
548     head += bytes;
549     num_bytes -= bytes;
550   }
551   response->mutable_transport_options()->PackFrom(extra);
552 }
553 }  // namespace
554 
RecvBufAsync(CallOptions * opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)555 void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
556                               RecvBufResponse* response, StatusCallback done) {
557   // This is a generic, low performance implementation appropriate for grpc.
558   Status s = recent_request_ids_.TrackUnique(request->request_id(),
559                                              "RecvBuf (GrpcWorker)", *request);
560   if (!s.ok()) {
561     done(s);
562     return;
563   }
564   CollectiveExecutor::Handle ce_handle(
565       env_->collective_executor_mgr->FindOrCreate(request->step_id()), true);
566   CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
567   rma->buf_rendezvous()->ConsumeBuf(
568       request->buf_rendezvous_key(),
569       [this, request, response, done](const Status& status,
570                                       BufRendezvous::Hook* hook) {
571         Status s = status;
572         if (s.ok()) {
573           if (!DMAHelper::CanUseDMA(hook->prod_value)) {
574             s = errors::Internal("Tensor value for key ",
575                                  request->buf_rendezvous_key(),
576                                  " is not of a type supported by RecvBuf");
577           }
578         }
579         if (s.ok()) {
580           // The RPC source tensor needs to be in CPU RAM.  If not already
581           // there make a copy using memory appropriate to the purpose.
582           const size_t num_bytes = hook->prod_value->TotalBytes();
583           const bool on_host =
584               hook->prod_dev->attributes().device_type() == "CPU" ||
585               hook->prod_attr.on_host();
586           if ((!on_host) && (num_bytes > 0)) {
587             Device* cpu_dev = nullptr;
588             s = env_->device_mgr->LookupDevice("CPU:0", &cpu_dev);
589             if (s.ok()) {
590               AllocatorAttributes cpu_attr;
591               cpu_attr.set_gpu_compatible(true);
592               cpu_attr.set_nic_compatible(true);
593               Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr),
594                                               hook->prod_value->dtype(),
595                                               hook->prod_value->shape());
596               hook->prod_ctx->CopyDeviceTensorToCPU(
597                   hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor,
598                   [this, num_bytes, response, done, hook,
599                    cpu_tensor](const Status& s) {
600                     if (s.ok()) {
601                       SetTensorInRecvBufResp(recv_buf_max_chunk_, cpu_tensor,
602                                              num_bytes, response);
603                     }
604                     response->set_send_start_micros(env_->env->NowMicros());
605                     done(s);
606                     BufRendezvous::DoneWithHook(hook);
607                     delete cpu_tensor;
608                   });
609               return;
610             }
611           } else {
612             // Tensor is on CPU.
613             SetTensorInRecvBufResp(recv_buf_max_chunk_, hook->prod_value,
614                                    num_bytes, response);
615           }
616         }
617         response->set_send_start_micros(env_->env->NowMicros());
618         done(s);
619         BufRendezvous::DoneWithHook(hook);
620       });
621 }
622 
LoggingAsync(const LoggingRequest * request,LoggingResponse * response,StatusCallback done)623 void GrpcWorker::LoggingAsync(const LoggingRequest* request,
624                               LoggingResponse* response, StatusCallback done) {
625   auto env = this->env();
626   if (env) {
627     auto session_mgr = env->session_mgr;
628     if (session_mgr) {
629       if (request->enable_rpc_logging()) {
630         session_mgr->SetLogging(true);
631       }
632       // NOTE(mrry): Handle old masters that disable RPC logging by setting
633       // `request->enable_rpc_logging` to `false`.
634       if (request->disable_rpc_logging() ||
635           (!request->enable_rpc_logging() &&
636            request->fetch_step_id_size() == 0)) {
637         session_mgr->SetLogging(false);
638       }
639       for (const auto& step_id : request->fetch_step_id()) {
640         session_mgr->RetrieveLogs(step_id, response);
641       }
642       if (request->clear()) {
643         session_mgr->ClearLogs();
644       }
645     }
646   }
647   done(Status::OK());
648 }
649 
env()650 WorkerEnv* GrpcWorker::env() { return env_; }
651 
NewGrpcWorker(WorkerEnv * env,const ConfigProto & config)652 std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
653                                           const ConfigProto& config) {
654   return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));
655 }
656 
NewGrpcWorkerService(GrpcWorker * worker,::grpc::ServerBuilder * builder,GrpcWorkerServiceOptions options)657 std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
658     GrpcWorker* worker, ::grpc::ServerBuilder* builder,
659     GrpcWorkerServiceOptions options) {
660   return std::unique_ptr<AsyncServiceInterface>(
661       new GrpcWorkerService(worker, builder, options));
662 }
663 
664 }  // namespace tensorflow
665