1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
17 
18 #include <limits>
19 #include <map>
20 #include <unordered_map>
21 
22 #include "grpcpp/create_channel.h"
23 
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/gtl/map_util.h"
27 #include "tensorflow/core/lib/strings/numbers.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34 #include "tensorflow/core/platform/types.h"
35 #include "tensorflow/core/util/device_name_utils.h"
36 
37 namespace tensorflow {
38 
39 namespace {
40 
// Builds the worker name used throughout this file for task `task` of job
// `job`. The replica component is always 0, e.g. "/job:worker/replica:0/task:3".
string MakeAddress(const string& job, int task) {
  const string job_prefix = strings::StrCat("/job:", job);
  return strings::StrCat(job_prefix, "/replica:0/task:", task);
}
44 
45 // Allows the host to be a raw IP (either v4 or v6).
ValidateHostPortPair(const string & host_port)46 Status ValidateHostPortPair(const string& host_port) {
47   uint32 port;
48   auto colon_index = host_port.find_last_of(':');
49   if (!strings::safe_strtou32(host_port.substr(colon_index + 1), &port) ||
50       host_port.substr(0, colon_index).find("/") != string::npos) {
51     return errors::InvalidArgument("Could not interpret \"", host_port,
52                                    "\" as a host-port pair.");
53   }
54   return Status::OK();
55 }
56 
57 }  // namespace
58 
GetChannelArguments(const RPCOptions * rpc_options)59 ::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options) {
60   // TODO(mrry): Implement secure channels.
61   ::grpc::ChannelArguments args;
62   args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
63   // NOTE(mrry): Some versions of gRPC use a 20-second minimum backoff
64   // on connection failure, which makes our tests time out.
65   args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
66   if (rpc_options != nullptr) {
67     if (rpc_options->compression_algorithm() == "deflate") {
68       args.SetCompressionAlgorithm(GRPC_COMPRESS_DEFLATE);
69       args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
70                   rpc_options->compression_level());
71       VLOG(5) << "Setting GRPC compression : algo='"
72               << rpc_options->compression_algorithm()
73               << "' level=" << rpc_options->compression_level();
74     } else if (rpc_options->compression_algorithm() == "gzip") {
75       args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP);
76       args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
77                   rpc_options->compression_level());
78       VLOG(5) << "Setting GRPC compression : algo='"
79               << rpc_options->compression_algorithm()
80               << "' level=" << rpc_options->compression_level();
81     } else if (!rpc_options->compression_algorithm().empty()) {
82       LOG(ERROR) << "Invalid compression algorithm: "
83                  << rpc_options->compression_algorithm();
84     }
85   }
86   return args;
87 }
88 
NewHostPortGrpcChannel(const string & target,const RPCOptions * rpc_options,SharedGrpcChannelPtr * channel_pointer)89 Status NewHostPortGrpcChannel(const string& target,
90                               const RPCOptions* rpc_options,
91                               SharedGrpcChannelPtr* channel_pointer) {
92   // Minimally ensure that the target is valid
93   TF_RETURN_IF_ERROR(ValidateHostPortPair(target));
94 
95   ::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
96   *channel_pointer = ::grpc::CreateCustomChannel(
97       "dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
98   return Status::OK();
99 }
100 
ConvertToChannelCreationFunction(const std::function<Status (string,const RPCOptions *,SharedGrpcChannelPtr *)> & new_channel_func_ptr)101 ChannelCreationFunction ConvertToChannelCreationFunction(
102     const std::function<Status(string, const RPCOptions*,
103                                SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
104   return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
105     SharedGrpcChannelPtr channel_ptr;
106     if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
107             .ok()) {
108       return channel_ptr;
109     } else {
110       return nullptr;
111     }
112   };
113 }
114 
AddHostPortsJob(const string & job_id,const std::vector<string> & host_ports)115 Status GrpcChannelSpec::AddHostPortsJob(const string& job_id,
116                                         const std::vector<string>& host_ports) {
117   std::map<int, string> host_ports_map;
118   for (size_t i = 0; i < host_ports.size(); ++i) {
119     host_ports_map[i] = host_ports[i];
120   }
121   return AddHostPortsJob(job_id, host_ports_map);
122 }
123 
AddHostPortsJob(const string & job_id,const std::map<int,string> & host_ports)124 Status GrpcChannelSpec::AddHostPortsJob(
125     const string& job_id, const std::map<int, string>& host_ports) {
126   if (!job_ids_.insert(job_id).second) {
127     return errors::InvalidArgument(
128         "Duplicate job ID in cluster specification: ", job_id);
129   }
130   for (const auto& id_host_port : host_ports) {
131     TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second));
132   }
133   host_ports_jobs_.emplace_back(job_id, host_ports);
134   return Status::OK();
135 }
136 
137 namespace {
138 
139 // GrpcChannelCache that caches results to FindWorkerChannel() calls.
140 class CachingGrpcChannelCache : public GrpcChannelCache {
141  public:
CachingGrpcChannelCache()142   CachingGrpcChannelCache() {}
143 
~CachingGrpcChannelCache()144   ~CachingGrpcChannelCache() override {}
145 
FindWorkerChannel(const string & target)146   SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
147     SharedGrpcChannelPtr ch = nullptr;
148     {
149       mutex_lock l(mu_);  // could use reader lock
150       ch = gtl::FindPtrOrNull(channels_, target);
151       if (ch) {
152         return ch;
153       }
154     }
155     ch = FindChannelOnce(target);
156     if (ch) {
157       mutex_lock l(mu_);
158       channels_.insert({target, ch});
159     }
160     return ch;
161   }
162 
163  protected:
164   // Find the ClientChannel for "target".  Only called when no channel was
165   // found in the channels_ cache for "target".  A non nullptr result will be
166   // cached in channels_.
167   virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;
168 
169  private:
170   // TODO(zhifengc): Eviction when the map becomes too big.
171   mutex mu_;
172   std::unordered_map<string, SharedGrpcChannelPtr> channels_ GUARDED_BY(mu_);
173 };
174 
175 // A ChannelCache that is the union of multiple ChannelCaches.
176 // Takes ownership of the caches passed to the constructor.
177 class MultiGrpcChannelCache : public CachingGrpcChannelCache {
178  public:
MultiGrpcChannelCache(const std::vector<GrpcChannelCache * > & caches)179   explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches)
180       : CachingGrpcChannelCache(), caches_(caches) {}
181 
~MultiGrpcChannelCache()182   ~MultiGrpcChannelCache() override {
183     for (GrpcChannelCache* cache : caches_) {
184       delete cache;
185     }
186   }
187 
ListWorkers(std::vector<string> * workers)188   void ListWorkers(std::vector<string>* workers) override {
189     for (GrpcChannelCache* cache : caches_) {
190       cache->ListWorkers(workers);
191     }
192   }
193 
ListWorkersInJob(const string & job_name,std::vector<string> * workers)194   void ListWorkersInJob(const string& job_name,
195                         std::vector<string>* workers) override {
196     for (GrpcChannelCache* cache : caches_) {
197       cache->ListWorkersInJob(job_name, workers);
198     }
199   }
200 
TranslateTask(const string & target)201   string TranslateTask(const string& target) override {
202     mutex_lock l(mu_);  // could use reader lock
203     GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
204     if (cache == nullptr) {
205       for (GrpcChannelCache* c : caches_) {
206         string r = c->TranslateTask(target);
207         if (!r.empty()) {
208           target_caches_.insert({target, c});
209           cache = c;
210           break;
211         }
212       }
213     }
214     CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
215                  << target;
216     return cache->TranslateTask(target);
217   }
218 
219  protected:
FindChannelOnce(const string & target)220   SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
221     for (GrpcChannelCache* cache : caches_) {
222       SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
223       if (ch) {
224         mutex_lock l(mu_);
225         target_caches_.insert({target, cache});
226         return ch;
227       }
228     }
229     return nullptr;
230   }
231 
232  private:
233   // List of channels used by this MultiGrpcChannelCache.
234   const std::vector<GrpcChannelCache*> caches_;
235 
236   mutex mu_;
237   // Cache of channels keyed by the target they are handling.
238   // The same GrpcChannelCache can appear multiple times in the cache.
239   std::unordered_map<string, GrpcChannelCache*> target_caches_ GUARDED_BY(mu_);
240 };
241 
242 class SparseGrpcChannelCache : public CachingGrpcChannelCache {
243  public:
SparseGrpcChannelCache(const string & job_id,const std::map<int,string> & host_ports,ChannelCreationFunction channel_func)244   SparseGrpcChannelCache(const string& job_id,
245                          const std::map<int, string>& host_ports,
246                          ChannelCreationFunction channel_func)
247       : job_id_(job_id),
248         host_ports_(host_ports),
249         channel_func_(std::move(channel_func)) {
250     LOG(INFO) << "Initialize GrpcChannelCache for job " << ToString();
251   }
~SparseGrpcChannelCache()252   ~SparseGrpcChannelCache() override {}
253 
ListWorkers(std::vector<string> * workers)254   void ListWorkers(std::vector<string>* workers) override {
255     workers->reserve(workers->size() + host_ports_.size());
256     for (const auto& id_host_port : host_ports_) {
257       workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
258     }
259   }
260 
ListWorkersInJob(const string & job_name,std::vector<string> * workers)261   void ListWorkersInJob(const string& job_name,
262                         std::vector<string>* workers) override {
263     if (job_name == job_id_) {
264       ListWorkers(workers);
265     }
266   }
267 
TranslateTask(const string & target)268   string TranslateTask(const string& target) override {
269     DeviceNameUtils::ParsedName parsed;
270     if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
271       LOG(WARNING) << "Invalid target: " << target;
272       return "";
273     }
274 
275     if (!parsed.has_job || parsed.job != job_id_) {
276       return "";
277     }
278     if (!parsed.has_replica || parsed.replica != 0) {
279       LOG(WARNING) << "Replica ID must be 0 in target: " << target;
280       return "";
281     }
282     int32 task = parsed.has_task ? parsed.task : -1;
283     auto iter = host_ports_.find(task);
284     if (iter == host_ports_.end()) {
285       LOG(WARNING) << "Task " << task << " was not defined in sparse job "
286                    << job_id_ << ": " << target;
287       return "";
288     }
289     return iter->second;
290   }
291 
292  protected:
FindChannelOnce(const string & target)293   SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
294     const string host_port = TranslateTask(target);
295     if (host_port.empty()) {
296       return nullptr;
297     }
298     return channel_func_(host_port);
299   }
300 
301  private:
ToString()302   string ToString() {
303     std::vector<string> task_strings;
304     task_strings.reserve(host_ports_.size());
305     for (const auto& id_host_port : host_ports_) {
306       task_strings.emplace_back(
307           strings::StrCat(id_host_port.first, " -> ", id_host_port.second));
308     }
309     return strings::StrCat(job_id_, " -> {", str_util::Join(task_strings, ", "),
310                            "}");
311   }
312 
313   const string job_id_;
314   const std::map<int, string> host_ports_;
315   const ChannelCreationFunction channel_func_;
316   TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
317 };
318 
319 }  // namespace
320 
NewGrpcChannelCache(const GrpcChannelSpec & spec,ChannelCreationFunction channel_func)321 GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
322                                       ChannelCreationFunction channel_func) {
323   const int num_jobs = spec.host_ports_jobs().size();
324   if (!num_jobs) {
325     LOG(ERROR) << "Empty channel spec.";
326     return nullptr;
327   }
328   std::vector<GrpcChannelCache*> caches;
329   caches.reserve(num_jobs);
330   for (auto& job : spec.host_ports_jobs()) {
331     caches.push_back(
332         new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func));
333   }
334   return caches.size() == 1 ? caches[0] : new MultiGrpcChannelCache(caches);
335 }
336 
337 }  // end namespace tensorflow
338