/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
17 
18 #include <unordered_map>
19 
20 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
21 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
22 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
23 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
24 #include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
25 #include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
26 #include "tensorflow/core/distributed_runtime/worker_interface.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/mutex.h"
29 
30 namespace tensorflow {
31 
32 namespace {
33 
34 class GrpcWorkerCache : public WorkerCachePartial {
35  public:
36   // TODO(ncteisen): consider adding a config var or flag for this
37   static constexpr const size_t kGrpcWorkerCacheThreadCount = 8;
38 
GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,WorkerInterface * local_worker,const string & local_target)39   explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
40                            WorkerInterface* local_worker,
41                            const string& local_target)
42       : local_target_(local_target),
43         local_worker_(local_worker),
44         channel_cache_(channel_cache),
45         threads_(kGrpcWorkerCacheThreadCount),
46         next_round_robin_assignment_(0) {
47     // NOTE: We don't yet have any reason to assign NUMA affinity to this
48     // ThreadPool.  If there's only a single NIC it shouldn't make any
49     // difference since presumably it is handling memory from all nodes.
50     ThreadOptions options;
51     options.numa_node = port::kNUMANoAffinity;
52     const int kNumCallbackThreads = 10;
53     callback_threadpool_.reset(new thread::ThreadPool(
54         Env::Default(), options, "grpc_wcache_callback", kNumCallbackThreads,
55         false /*low_latency_hint*/, nullptr /*allocator*/));
56   }
57 
58   // Explicit destructor to control destruction order.
~GrpcWorkerCache()59   ~GrpcWorkerCache() override {
60     threads_.clear();  // Blocks until threads exit.
61   }
62 
ListWorkers(std::vector<string> * workers) const63   void ListWorkers(std::vector<string>* workers) const override {
64     channel_cache_->ListWorkers(workers);
65   }
66 
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const67   void ListWorkersInJob(const string& job_name,
68                         std::vector<string>* workers) const override {
69     channel_cache_->ListWorkersInJob(job_name, workers);
70   }
71 
CreateWorker(const string & target)72   WorkerInterface* CreateWorker(const string& target) override {
73     if (target == local_target_) {
74       return local_worker_;
75     } else {
76       SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
77       if (!channel) return nullptr;
78       return NewGrpcRemoteWorker(
79           channel, threads_[AssignWorkerToThread(target)].completion_queue(),
80           callback_threadpool_.get(), &logger_);
81     }
82   }
83 
ReleaseWorker(const string & target,WorkerInterface * worker)84   void ReleaseWorker(const string& target, WorkerInterface* worker) override {
85     if (target == local_target_) {
86       CHECK_EQ(worker, local_worker_)
87           << "Releasing a worker that was not returned by this WorkerCache";
88     } else {
89       WorkerCacheInterface::ReleaseWorker(target, worker);
90     }
91   }
92 
SetLogging(bool v)93   void SetLogging(bool v) override { logger_.SetLogging(v); }
94 
ClearLogs()95   void ClearLogs() override { logger_.ClearLogs(); }
96 
RetrieveLogs(int64 step_id,StepStats * ss)97   bool RetrieveLogs(int64 step_id, StepStats* ss) override {
98     return logger_.RetrieveLogs(step_id, ss);
99   }
100 
101  private:
102   // Thread wrapping class that drives work over a single gRPC
103   // CompletionQueue.
104   class GrpcWorkerCacheThread {
105    public:
GrpcWorkerCacheThread()106     GrpcWorkerCacheThread() {
107       thread_.reset(Env::Default()->StartThread(
108           ThreadOptions(), "grpc_worker_cache", [this]() {
109             void* tag;
110             bool ok;
111             while (completion_queue_.Next(&tag, &ok)) {
112               GrpcClientCQTag* callback_tag =
113                   static_cast<GrpcClientCQTag*>(tag);
114               callback_tag->OnCompleted(ok);
115             }
116           }));
117     }
118 
~GrpcWorkerCacheThread()119     ~GrpcWorkerCacheThread() {
120       completion_queue_.Shutdown();
121       thread_.reset();
122     }
123 
completion_queue()124     ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }
125 
126    private:
127     ::grpc::CompletionQueue completion_queue_;
128     std::unique_ptr<Thread> thread_;
129   };  // GrpcWorkerCacheThread
130 
AssignWorkerToThread(const string & target)131   size_t AssignWorkerToThread(const string& target) {
132     // Round-robin target assignment, but keeps the same target on the same
133     // polling thread always, as this is important for gRPC performance
134     mutex_lock lock(assignment_mu_);
135     auto it = target_assignments_.find(target);
136     if (it == target_assignments_.end()) {
137       it = target_assignments_
138                .insert(std::make_pair(
139                    target, (next_round_robin_assignment_++) % threads_.size()))
140                .first;
141     }
142     return it->second;
143   }
144 
145   const string local_target_;
146   WorkerInterface* const local_worker_;  // Not owned.
147   std::shared_ptr<GrpcChannelCache> channel_cache_;
148   WorkerCacheLogger logger_;
149   std::vector<GrpcWorkerCacheThread> threads_;
150 
151   std::unique_ptr<thread::ThreadPool> callback_threadpool_;
152 
153   mutex assignment_mu_;
154   std::unordered_map<std::string, size_t> target_assignments_
155       GUARDED_BY(assignment_mu_);
156   size_t next_round_robin_assignment_ GUARDED_BY(assignment_mu_);
157 };
158 
159 }  // namespace
160 
NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc)161 WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc) {
162   return new GrpcWorkerCache(cc, nullptr, "");
163 }
164 
NewGrpcWorkerCacheWithLocalWorker(std::shared_ptr<GrpcChannelCache> cc,WorkerInterface * local_worker,const string & local_target)165 WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
166     std::shared_ptr<GrpcChannelCache> cc, WorkerInterface* local_worker,
167     const string& local_target) {
168   return new GrpcWorkerCache(cc, local_worker, local_target);
169 }
170 
171 }  // namespace tensorflow
172