1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
18 
19 // GrpcServer manages the lifecycle of an Eager, Worker and Master service.
20 
21 #include <memory>
22 
23 #include "grpcpp/grpcpp.h"
24 #include "grpcpp/security/credentials.h"
25 
26 #include "tensorflow/core/common_runtime/process_util.h"
27 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
28 #include "tensorflow/core/distributed_runtime/master_env.h"
29 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
30 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
31 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
32 #include "tensorflow/core/distributed_runtime/server_lib.h"
33 #include "tensorflow/core/distributed_runtime/session_mgr.h"
34 #include "tensorflow/core/distributed_runtime/worker_env.h"
35 #include "tensorflow/core/framework/collective.h"
36 #include "tensorflow/core/framework/op.h"
37 #include "tensorflow/core/platform/env.h"
38 
39 namespace tensorflow {
40 
41 class GrpcWorker;
42 class Master;
43 
44 // function that creates a RendezvousMgr.
45 typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
46     RendezvousMgrCreationFunction;
47 
48 // function that creates a CollectiveExecutorMgr.
49 typedef std::function<CollectiveExecutorMgrInterface*(
50     const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)>
51     CollectiveMgrCreationFunction;
52 
53 // function that registers a service to the server. The service needs to
54 // be registered before builder.BuildAndStart().
55 typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
56     ServiceInitFunction;
57 
58 // function that creates a grpc based worker implementation.
59 typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*,
60                                                   const ConfigProto& config)>
61     WorkerCreationFunction;
62 
63 struct GrpcServerOptions {
64   ServiceInitFunction service_func = nullptr;
65   RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr;
66   CollectiveMgrCreationFunction collective_mgr_func = nullptr;
67   WorkerCreationFunction worker_func = nullptr;
68   StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher;
69   GrpcWorkerServiceOptions worker_service_options;
70 };
71 
72 class GrpcServer : public ServerInterface {
73  protected:
74   GrpcServer(const ServerDef& server_def, Env* env);
75   // Allow children classes to override this and provide custom args to the
76   // server before it is constructed. Default behavior is to do nothing.
77   virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder);
78 
79  public:
80   static Status Create(const ServerDef& server_def, Env* env,
81                        std::unique_ptr<ServerInterface>* out_server);
82   static Status Create(const ServerDef& server_def, Env* env,
83                        std::unique_ptr<GrpcServer>* out_server);
84 
85   // Destruction is only supported in the factory method. Clean
86   // shutdown is not currently implemented for this server type.
87   virtual ~GrpcServer();
88 
89   // Implementations of ServerInterface methods.
90   Status Start() override;
91   Status Stop() override;
92   Status Join() override;
93   const string target() const override;
94 
worker_env()95   WorkerEnv* worker_env() { return &worker_env_; }
master_env()96   MasterEnv* master_env() { return &master_env_; }
97 
channel_cache()98   std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
99 
100  protected:
101   Status Init(const GrpcServerOptions& opts = GrpcServerOptions());
102 
103   // A subclass can override this method to support secure credentials.
104   virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
105       const ServerDef& server_def) const;
106 
107   virtual ChannelCreationFunction GetChannelCreationFunction() const;
108 
109   virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
110 
111   // Creates a WorkerCacheInterface for a session.
112   Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
113                             WorkerCacheInterface** worker_cache);
114 
115   // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
116   Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
117                           GrpcChannelSpec* channel_spec);
118 
119   // Returns the port to which this server is bound.
120   // This method may only be called after `this->Init()` returns successfully.
bound_port()121   int bound_port() const { return bound_port_; }
122 
server_def()123   const ServerDef& server_def() const { return server_def_; }
124 
125  private:
126   // The overall server configuration.
127   const ServerDef server_def_;
128   Env* env_;
129 
130   // The port to which this server is bound.
131   int bound_port_ = 0;
132 
133   // Guards state transitions.
134   mutex mu_;
135 
136   // Represents the current state of the server, which changes as follows:
137   //
138   //                 Join()            Join()
139   //                  ___               ___
140   //      Start()     \ /    Stop()     \ /
141   // NEW ---------> STARTED --------> STOPPED
142   //   \                          /
143   //    \________________________/
144   //            Stop(), Join()
145   enum State { NEW, STARTED, STOPPED };
146   State state_ GUARDED_BY(mu_);
147 
148   // Implementation of a TensorFlow master, and RPC polling thread.
149   MasterEnv master_env_;
150   std::unique_ptr<Master> master_impl_;
151   AsyncServiceInterface* master_service_ = nullptr;
152   std::unique_ptr<Thread> master_thread_ GUARDED_BY(mu_);
153   std::shared_ptr<GrpcChannelCache> channel_cache_;
154 
155   // Implementation of a TensorFlow worker, and RPC polling thread.
156   WorkerEnv worker_env_;
157   std::unique_ptr<GrpcWorker> worker_impl_;
158   AsyncServiceInterface* worker_service_ = nullptr;
159   std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_);
160 
161   // TensorFlow Eager implementation, and RPC polling thread.
162   AsyncServiceInterface* eager_service_ = nullptr;
163   std::unique_ptr<Thread> eager_thread_ GUARDED_BY(mu_);
164   std::shared_ptr<WorkerSession> worker_session_;
165 
166   std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_);
167 };
168 
169 }  // namespace tensorflow
170 
171 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
172