1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
17 
18 #include <cstring>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22 
23 #include "grpc/support/alloc.h"
24 #include "grpcpp/grpcpp.h"
25 #include "grpcpp/security/credentials.h"
26 #include "grpcpp/server_builder.h"
27 
28 #include "tensorflow/core/common_runtime/device_factory.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/process_util.h"
31 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
32 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
33 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
34 #include "tensorflow/core/distributed_runtime/local_master.h"
35 #include "tensorflow/core/distributed_runtime/master.h"
36 #include "tensorflow/core/distributed_runtime/master_env.h"
37 #include "tensorflow/core/distributed_runtime/master_session.h"
38 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
39 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
40 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
41 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
42 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
43 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
44 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
45 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
46 #include "tensorflow/core/distributed_runtime/server_lib.h"
47 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
48 #include "tensorflow/core/distributed_runtime/worker_env.h"
49 #include "tensorflow/core/framework/op.h"
50 #include "tensorflow/core/lib/strings/strcat.h"
51 #include "tensorflow/core/platform/env.h"
52 #include "tensorflow/core/platform/mem.h"
53 #include "tensorflow/core/public/session_options.h"
54 
55 namespace tensorflow {
56 
57 namespace {
58 
59 // Define an option subclass in order to disable SO_REUSEPORT for the
60 // server socket.
61 class NoReusePortOption : public ::grpc::ServerBuilderOption {
62  public:
UpdateArguments(::grpc::ChannelArguments * args)63   void UpdateArguments(::grpc::ChannelArguments* args) override {
64     args->SetInt(GRPC_ARG_ALLOW_REUSEPORT, 0);
65   }
66 
UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>> * plugins)67   void UpdatePlugins(std::vector<std::unique_ptr<::grpc::ServerBuilderPlugin>>*
68                          plugins) override {}
69 };
70 
71 // static utility function
NewRpcRendezvousMgr(const WorkerEnv * env)72 RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
73   return new RpcRendezvousMgr(env);
74 }
75 
76 }  // namespace
77 
GrpcServer(const ServerDef & server_def,Env * env)78 GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
79     : server_def_(server_def), env_(env), state_(NEW) {}
80 
// Tears the server down.
//
// Stop() and Join() run first, and each must return OK (TF_CHECK_OK). Note
// that Stop() returns Unimplemented while the server is STARTED, so this
// check fails for a server that was ever started.
//
// Afterwards, the service objects and the worker-environment members that
// Init() allocated are deleted, in an order that respects their
// dependencies.
GrpcServer::~GrpcServer() {
  TF_CHECK_OK(Stop());
  TF_CHECK_OK(Join());

  // Service objects were allocated in Init() (the worker service was
  // released from a unique_ptr there), so they are freed manually here.
  delete master_service_;
  delete worker_service_;
  delete eager_service_;

  // TODO(mrry): Refactor the *Env classes so that it is less fiddly
  // to destroy them.

  // Shut down all outstanding rendezvous.
  delete worker_env_.rendezvous_mgr;

  // We must delete graph_mgr before device_mgr, due to shared
  // ownership of OpKernels in the executors. (The graph_mgr will
  // free all stateless OpKernels, and pass over borrowed stateful
  // OpKernels, which are also held in their respective devices'
  // OpSegments.)
  if (worker_env_.session_mgr != nullptr) {
    delete worker_env_.session_mgr;  // Deletes graph_mgr's.
  } else {
    // Note: session_mgr's legacy_session_ deletes device_mgr now.
    // (session_mgr is null when Init() failed before creating it.)
    delete worker_env_.device_mgr;
  }

  // Do not delete (as these are not owned by the server):
  // - master_env_.env
  // - worker_env_.env
  // - worker_env_.compute_pool
}
112 
MaybeMutateBuilder(::grpc::ServerBuilder * builder)113 void GrpcServer::MaybeMutateBuilder(::grpc::ServerBuilder* builder) {}
114 
Init(const GrpcServerOptions & opts)115 Status GrpcServer::Init(const GrpcServerOptions& opts) {
116   mutex_lock l(mu_);
117   CHECK_EQ(state_, NEW);
118   master_env_.env = env_;
119   worker_env_.env = env_;
120 
121   // Check parameters before DeviceFactory::AddDevices,
122   // otherwise if 'task_index=-1' the program will abort.
123 
124   // Look up the port that has been requested for this task in `server_def_`.
125   int requested_port = -1;
126   for (const auto& job : server_def_.cluster().job()) {
127     if (job.name() == server_def_.job_name()) {
128       auto iter = job.tasks().find(server_def_.task_index());
129       if (iter == job.tasks().end()) {
130         return errors::InvalidArgument("Task ", server_def_.task_index(),
131                                        " was not defined in job \"",
132                                        server_def_.job_name(), "\"");
133       }
134       auto colon_index = iter->second.find_last_of(':');
135       if (!strings::safe_strto32(iter->second.substr(colon_index + 1),
136                                  &requested_port)) {
137         return errors::InvalidArgument(
138             "Could not parse port for local server from \"", iter->second,
139             "\".");
140       }
141       break;
142     }
143   }
144   if (requested_port == -1) {
145     return errors::Internal("Job \"", server_def_.job_name(),
146                             "\" was not defined in cluster");
147   }
148 
149   SessionOptions sess_opts;
150   ConfigProto config = server_def_.default_session_config();
151   sess_opts.config = config;
152 
153   // Configure shared devices between master and worker.
154   string name_prefix =
155       strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
156                       "/task:", server_def_.task_index());
157   std::vector<std::unique_ptr<Device>> devices;
158   TF_RETURN_IF_ERROR(
159       DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
160   worker_env_.device_mgr = new DeviceMgr(std::move(devices));
161   master_env_.local_devices = worker_env_.device_mgr->ListDevices();
162   worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
163   worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
164                                    ? new RpcRendezvousMgr(&worker_env_)
165                                    : opts.rendezvous_mgr_func(&worker_env_);
166   string unused;
167   string default_worker_name;
168   if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
169                                         &default_worker_name, &unused)) {
170     return errors::Internal("Could not parse worker name.");
171   }
172 
173   // N.B. The order of initialization here is intricate, because we
174   // wish to allow `requested_port == 0` (for choosing any port,
175   // mostly for testing). Therefore, the construction of the channel
176   // and worker caches depends on `bound_port_`, which is not set
177   // until we call `builder.BuildAndStart()`. We must create the
178   // service objects before calling `builder.BuildAndStart()`, but
179   // `master_env_` and `worker_env_` are only partially
180   // configured. However, this is not dangerous, because we do not
181   // start serving requests until `this->Start()` is called, which
182   // happens after this method returns.
183   //
184   // TODO(mrry): Provide a general mechanism for dynamically setting
185   // the identities of tasks in the worker pool after the service is
186   // running.
187   ::grpc::ServerBuilder builder;
188   builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port),
189                            GetServerCredentials(server_def_), &bound_port_);
190   builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
191 
192   builder.SetOption(
193       std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
194   // Allow subclasses to specify more args to pass to the gRPC server.
195   MaybeMutateBuilder(&builder);
196   master_impl_ = CreateMaster(&master_env_);
197   master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
198   worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
199                                   : NewGrpcWorker(&worker_env_, config);
200   worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
201                                          opts.worker_service_options)
202                         .release();
203   eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
204 
205   // extra service:
206   if (opts.service_func != nullptr) {
207     opts.service_func(&worker_env_, &builder);
208   }
209   server_ = builder.BuildAndStart();
210 
211   if (!server_) {
212     return errors::Unknown("Could not start gRPC server");
213   }
214 
215   WorkerCacheInterface* worker_cache;
216   WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
217   TF_RETURN_IF_ERROR(
218       WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
219   CHECK_NE(nullptr, worker_cache);
220 
221   if (opts.collective_mgr_func) {
222     worker_env_.collective_executor_mgr =
223         opts.collective_mgr_func(config, &worker_env_, worker_cache);
224     if (!worker_env_.collective_executor_mgr) {
225       return errors::Internal(
226           "collective_mgr_func did not return CollectiveExecutorMgr");
227     }
228   } else {
229     std::unique_ptr<DeviceResolverDistributed> dev_resolver(
230         new DeviceResolverDistributed(worker_env_.device_mgr, worker_cache,
231                                       default_worker_name));
232     std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
233         new CollectiveParamResolverDistributed(config, worker_env_.device_mgr,
234                                                dev_resolver.get(), worker_cache,
235                                                default_worker_name));
236     worker_env_.collective_executor_mgr = new RpcCollectiveExecutorMgr(
237         config, worker_env_.device_mgr, std::move(dev_resolver),
238         std::move(param_resolver), worker_cache, default_worker_name);
239   }
240 
241   // Set up worker environment.
242   worker_env_.session_mgr = new SessionMgr(
243       &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
244       std::unique_ptr<WorkerCacheInterface>(worker_cache),
245       [this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
246         WorkerCacheFactoryOptions options(server_def);
247         return WorkerCacheFactory(options, worker_cache);
248       });
249   worker_env_.compute_pool = ComputePool(sess_opts);
250 
251   // Finish setting up master environment.
252   master_env_.ops = OpRegistry::Global();
253   master_env_.worker_cache = worker_cache;
254   master_env_.collective_executor_mgr = worker_env_.collective_executor_mgr;
255   StatsPublisherFactory stats_factory = opts.stats_factory;
256   master_env_.master_session_factory =
257       [config, stats_factory](
258           SessionOptions options, const MasterEnv* env,
259           std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
260           std::unique_ptr<WorkerCacheInterface> worker_cache,
261           std::unique_ptr<DeviceSet> device_set,
262           std::vector<string> filtered_worker_list) {
263         options.config.MergeFrom(config);
264         return new MasterSession(options, env, std::move(remote_devs),
265                                  std::move(worker_cache), std::move(device_set),
266                                  std::move(filtered_worker_list),
267                                  stats_factory);
268       };
269   master_env_.worker_cache_factory =
270       [this](const WorkerCacheFactoryOptions& options,
271              WorkerCacheInterface** worker_cache) {
272         return WorkerCacheFactory(options, worker_cache);
273       };
274 
275   // Provide direct access to the master from in-process clients.
276   LocalMaster::Register(target(), master_impl_.get(),
277                         config.operation_timeout_in_ms());
278 
279   return Status::OK();
280 }
281 
ParseChannelSpec(const WorkerCacheFactoryOptions & options,GrpcChannelSpec * channel_spec)282 Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
283                                     GrpcChannelSpec* channel_spec) {
284   for (const auto& job : options.cluster_def->job()) {
285     std::map<int, string> host_ports;
286     for (const auto& task : job.tasks()) {
287       string& host_port = host_ports[task.first];
288       if (!host_port.empty()) {
289         return errors::InvalidArgument("JobDef for job \"", job.name(),
290                                        "\" specified two addresses for task \"",
291                                        task.first, "\": ", host_port, " and ",
292                                        task.second);
293       }
294       if (job.name() == *options.job_name && task.first == options.task_index) {
295         host_port = strings::StrCat("localhost:", bound_port_);
296       } else {
297         host_port = task.second;
298       }
299     }
300     TF_RETURN_IF_ERROR(channel_spec->AddHostPortsJob(job.name(), host_ports));
301   }
302   return Status::OK();
303 }
304 
WorkerCacheFactory(const WorkerCacheFactoryOptions & options,WorkerCacheInterface ** worker_cache)305 Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
306                                       WorkerCacheInterface** worker_cache) {
307   if (options.job_name == nullptr || options.job_name->empty()) {
308     Status s = errors::InvalidArgument(
309         "The master (current machine) is not included in the provided "
310         "cluster_def. ",
311         options.cluster_def->DebugString());
312     LOG(WARNING) << s;
313     return s;
314   }
315 
316   GrpcChannelSpec channel_spec;
317   TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
318 
319   channel_cache_.reset(
320       NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
321 
322   string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
323                                        "/task:", options.task_index);
324 
325   const string host_port = channel_cache_->TranslateTask(name_prefix);
326   int requested_port;
327 
328   auto colon_index = host_port.find_last_of(':');
329   if (!strings::safe_strto32(host_port.substr(colon_index + 1),
330                              &requested_port)) {
331     return errors::Internal("Could not parse port for local server from \"",
332                             host_port, "\".");
333   }
334 
335   if (requested_port != bound_port_) {
336     return errors::InvalidArgument("Requested port ", requested_port,
337                                    " differs from expected port ", bound_port_);
338   }
339 
340   *worker_cache = NewGrpcWorkerCacheWithLocalWorker(
341       channel_cache_, worker_impl_.get(), name_prefix);
342   return Status::OK();
343 }
344 
Start()345 Status GrpcServer::Start() {
346   mutex_lock l(mu_);
347   switch (state_) {
348     case NEW: {
349       master_thread_.reset(
350           env_->StartThread(ThreadOptions(), "TF_master_service",
351                             [this] { master_service_->HandleRPCsLoop(); }));
352       worker_thread_.reset(
353           env_->StartThread(ThreadOptions(), "TF_worker_service",
354                             [this] { worker_service_->HandleRPCsLoop(); }));
355       eager_thread_.reset(
356           env_->StartThread(ThreadOptions(), "TF_eager_service",
357                             [this] { eager_service_->HandleRPCsLoop(); }));
358       state_ = STARTED;
359       LOG(INFO) << "Started server with target: " << target();
360       return Status::OK();
361     }
362     case STARTED:
363       LOG(INFO) << "Server already started (target: " << target() << ")";
364       return Status::OK();
365     case STOPPED:
366       return errors::FailedPrecondition("Server has stopped.");
367     default:
368       LOG(FATAL);
369   }
370 }
371 
Stop()372 Status GrpcServer::Stop() {
373   mutex_lock l(mu_);
374   switch (state_) {
375     case NEW:
376       state_ = STOPPED;
377       return Status::OK();
378     case STARTED:
379       return errors::Unimplemented(
380           "Clean shutdown is not currently implemented");
381     case STOPPED:
382       LOG(INFO) << "Server already stopped (target: " << target() << ")";
383       return Status::OK();
384     default:
385       LOG(FATAL);
386   }
387 }
388 
Join()389 Status GrpcServer::Join() {
390   mutex_lock l(mu_);
391   switch (state_) {
392     case NEW:
393       // Prevent the server from being started subsequently.
394       state_ = STOPPED;
395       return Status::OK();
396     case STARTED:
397     case STOPPED:
398       master_thread_.reset();
399       worker_thread_.reset();
400       eager_thread_.reset();
401       return Status::OK();
402     default:
403       LOG(FATAL);
404   }
405 }
406 
target() const407 const string GrpcServer::target() const {
408   return strings::StrCat("grpc://localhost:", bound_port_);
409 }
410 
GetServerCredentials(const ServerDef & server_def) const411 std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
412     const ServerDef& server_def) const {
413   return ::grpc::InsecureServerCredentials();
414 }
415 
GetChannelCreationFunction() const416 ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
417   // We can do this because SparseGrpcChannelCache is robust to nullptr being
418   // returned by the channel creation function
419   return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
420 }
421 
CreateMaster(MasterEnv * master_env)422 std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
423   return std::unique_ptr<Master>(new Master(master_env, 0.0));
424 }
425 
426 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<ServerInterface> * out_server)427 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
428                           std::unique_ptr<ServerInterface>* out_server) {
429   std::unique_ptr<GrpcServer> ret(
430       new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
431   ServiceInitFunction service_func = nullptr;
432   GrpcServerOptions options;
433   options.rendezvous_mgr_func = NewRpcRendezvousMgr;
434   Status s = ret->Init(options);
435   if (!s.ok()) {
436     LOG(ERROR) << s;
437     return s;
438   }
439   *out_server = std::move(ret);
440   return Status::OK();
441 }
442 
443 /* static */
Create(const ServerDef & server_def,Env * env,std::unique_ptr<GrpcServer> * out_server)444 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
445                           std::unique_ptr<GrpcServer>* out_server) {
446   std::unique_ptr<GrpcServer> ret(
447       new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
448   GrpcServerOptions options;
449   options.rendezvous_mgr_func = NewRpcRendezvousMgr;
450   Status s = ret->Init(options);
451   if (!s.ok()) {
452     LOG(ERROR) << s;
453     return s;
454   }
455   *out_server = std::move(ret);
456   return Status::OK();
457 }
458 
459 namespace {
460 
461 class GrpcServerFactory : public ServerFactory {
462  public:
AcceptsOptions(const ServerDef & server_def)463   bool AcceptsOptions(const ServerDef& server_def) override {
464     return server_def.protocol() == "grpc";
465   }
466 
NewServer(const ServerDef & server_def,std::unique_ptr<ServerInterface> * out_server)467   Status NewServer(const ServerDef& server_def,
468                    std::unique_ptr<ServerInterface>* out_server) override {
469     return GrpcServer::Create(server_def, Env::Default(), out_server);
470   }
471 };
472 
473 // Registers a `ServerFactory` for `GrpcServer` instances.
474 class GrpcServerRegistrar {
475  public:
GrpcServerRegistrar()476   GrpcServerRegistrar() {
477     gpr_allocation_functions alloc_fns;
478     memset(&alloc_fns, 0, sizeof(alloc_fns));
479     alloc_fns.malloc_fn = port::Malloc;
480     alloc_fns.realloc_fn = port::Realloc;
481     alloc_fns.free_fn = port::Free;
482     gpr_set_allocation_functions(alloc_fns);
483     ServerFactory::Register("GRPC_SERVER", new GrpcServerFactory());
484   }
485 };
486 static GrpcServerRegistrar registrar;
487 
488 }  // namespace
489 }  // namespace tensorflow
490