1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
17 
18 #include "grpcpp/generic/generic_stub.h"
19 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
20 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
21 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
22 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/platform/env.h"
25 #include "tensorflow/core/protobuf/eager_service.pb.h"
26 
27 namespace tensorflow {
28 namespace eager {
29 namespace {
30 class GrpcEagerClient : public EagerClient {
31  public:
GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr & channel,::grpc::CompletionQueue * cq)32   GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
33                   ::grpc::CompletionQueue* cq)
34       : stub_(channel), cq_(cq) {}
~GrpcEagerClient()35   ~GrpcEagerClient() override {}
36 
37 #define CLIENT_METHOD(method)                                             \
38   void method##Async(const method##Request* request,                      \
39                      method##Response* response, StatusCallback done)     \
40       override {                                                          \
41     new RPCState<protobuf::Message>(                                      \
42         &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
43         response, std::move(done), nullptr, nullptr);                     \
44   }
45 
46   CLIENT_METHOD(CreateContext);
47   CLIENT_METHOD(Enqueue);
48   CLIENT_METHOD(WaitQueueDone);
49   CLIENT_METHOD(KeepAlive);
50   CLIENT_METHOD(CloseContext);
51   CLIENT_METHOD(RegisterFunction);
52   CLIENT_METHOD(SendTensor);
53 
54 #undef CLIENT_METHOD
55 
56  private:
57   ::grpc::GenericStub stub_;
58   ::grpc::CompletionQueue* cq_;
59 };
60 
61 class GrpcEagerClientCache : public EagerClientCache {
62  public:
GrpcEagerClientCache(std::shared_ptr<tensorflow::GrpcChannelCache> cache)63   explicit GrpcEagerClientCache(
64       std::shared_ptr<tensorflow::GrpcChannelCache> cache)
65       : next_round_robin_assignment_(0), cache_(cache), threads_(4) {}
66 
~GrpcEagerClientCache()67   ~GrpcEagerClientCache() override { threads_.clear(); }
68 
GetClient(const string & target)69   EagerClient* GetClient(const string& target) override {
70     auto it = clients_.find(target);
71     if (it == clients_.end()) {
72       tensorflow::SharedGrpcChannelPtr shared =
73           cache_->FindWorkerChannel(target);
74       auto worker = std::unique_ptr<EagerClient>(new GrpcEagerClient(
75           shared, threads_[AssignClientToThread(target)].completion_queue()));
76 
77       it = clients_.emplace(target, std::move(worker)).first;
78     }
79 
80     return it->second.get();
81   }
82 
83  private:
84   mutex assignment_mu_;
85   std::unordered_map<std::string, size_t> target_assignments_
86       GUARDED_BY(assignment_mu_);
87   size_t next_round_robin_assignment_ GUARDED_BY(assignment_mu_);
88 
AssignClientToThread(const string & target)89   size_t AssignClientToThread(const string& target) {
90     // Round-robin target assignment, but keeps the same target on the same
91     // polling thread always, as this is important for gRPC performace
92     mutex_lock lock(assignment_mu_);
93     auto it = target_assignments_.find(target);
94     if (it == target_assignments_.end()) {
95       it = target_assignments_
96                .insert(std::make_pair(
97                    target, (next_round_robin_assignment_++) % threads_.size()))
98                .first;
99     }
100     return it->second;
101   }
102 
103   class GrpcEagerClientThread {
104    public:
GrpcEagerClientThread()105     GrpcEagerClientThread() {
106       thread_.reset(Env::Default()->StartThread(
107           ThreadOptions(), "eager_client_thread", [this]() {
108             void* tag;
109             bool ok;
110             while (completion_queue_.Next(&tag, &ok)) {
111               GrpcClientCQTag* callback_tag =
112                   static_cast<GrpcClientCQTag*>(tag);
113               callback_tag->OnCompleted(ok);
114             }
115           }));
116     }
117 
~GrpcEagerClientThread()118     ~GrpcEagerClientThread() {
119       completion_queue_.Shutdown();
120       thread_.reset();
121     }
122 
completion_queue()123     ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }
124 
125    private:
126     ::grpc::CompletionQueue completion_queue_;
127     std::unique_ptr<Thread> thread_;
128   };  // GrpcEagerClientThread
129 
130   std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
131   std::unordered_map<string, std::unique_ptr<EagerClient>> clients_;
132   std::vector<GrpcEagerClientThread> threads_;
133 };
134 
135 }  // namespace
136 
NewGrpcEagerClientCache(std::shared_ptr<tensorflow::GrpcChannelCache> channel)137 EagerClientCache* NewGrpcEagerClientCache(
138     std::shared_ptr<tensorflow::GrpcChannelCache> channel) {
139   return new GrpcEagerClientCache(channel);
140 }
141 
142 }  // namespace eager
143 }  // namespace tensorflow
144