1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/service/server_lib.h"
17 
18 #include "tensorflow/core/data/service/credentials_factory.h"
19 #include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
20 #include "tensorflow/core/data/service/grpc_util.h"
21 #include "tensorflow/core/data/service/grpc_worker_impl.h"
22 #include "tensorflow/core/platform/errors.h"
23 
24 namespace tensorflow {
25 namespace data {
26 
namespace {

// Token that WorkerGrpcDataServer::StartServiceInternal() substitutes with a
// concrete port number when building worker and transfer addresses.
constexpr char kPortPlaceholder[] = "%port%";

}  // namespace
30 
GrpcDataServerBase(int port,const std::string & protocol,const std::string server_type)31 GrpcDataServerBase::GrpcDataServerBase(int port, const std::string& protocol,
32                                        const std::string server_type)
33     : requested_port_(port),
34       protocol_(protocol),
35       server_type_(server_type),
36       bound_port_(port) {}
37 
Start()38 Status GrpcDataServerBase::Start() {
39   if (stopped_) {
40     return errors::FailedPrecondition(
41         "Server cannot be started after it has been stopped.");
42   }
43   if (started_) {
44     return Status::OK();
45   }
46   ::grpc::ServerBuilder builder;
47   std::shared_ptr<::grpc::ServerCredentials> credentials;
48   TF_RETURN_IF_ERROR(
49       CredentialsFactory::CreateServerCredentials(protocol_, &credentials));
50   builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_),
51                            credentials, &bound_port_);
52   builder.SetMaxReceiveMessageSize(-1);
53 
54   AddDataServiceToBuilder(builder);
55   AddProfilerServiceToBuilder(builder);
56   server_ = builder.BuildAndStart();
57   if (!server_) {
58     return errors::Internal("Could not start gRPC server");
59   }
60 
61   TF_RETURN_IF_ERROR(StartServiceInternal());
62 
63   started_ = true;
64   LOG(INFO) << "Started tf.data " << server_type_
65             << " running at 0.0.0.0:" << BoundPort();
66   return Status::OK();
67 }
68 
Stop()69 void GrpcDataServerBase::Stop() {
70   if (stopped_) {
71     return;
72   }
73   server_->Shutdown();
74   stopped_ = true;
75   LOG(INFO) << "Shut down " << server_type_ << " server running at port "
76             << BoundPort();
77 }
78 
Join()79 void GrpcDataServerBase::Join() { server_->Wait(); }
80 
BoundPort()81 int GrpcDataServerBase::BoundPort() { return bound_port(); }
82 
AddProfilerServiceToBuilder(::grpc::ServerBuilder & builder)83 void GrpcDataServerBase::AddProfilerServiceToBuilder(
84     ::grpc::ServerBuilder& builder) {
85   profiler_service_ = profiler::CreateProfilerService();
86   builder.RegisterService(profiler_service_.get());
87 }
88 
DispatchGrpcDataServer(const experimental::DispatcherConfig & config)89 DispatchGrpcDataServer::DispatchGrpcDataServer(
90     const experimental::DispatcherConfig& config)
91     : GrpcDataServerBase(config.port(), config.protocol(), "DispatchServer"),
92       config_(config) {}
93 
~DispatchGrpcDataServer()94 DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
95 
AddDataServiceToBuilder(::grpc::ServerBuilder & builder)96 void DispatchGrpcDataServer::AddDataServiceToBuilder(
97     ::grpc::ServerBuilder& builder) {
98   service_ = absl::make_unique<GrpcDispatcherImpl>(config_, builder).release();
99 }
100 
StartServiceInternal()101 Status DispatchGrpcDataServer::StartServiceInternal() {
102   return service_->Start();
103 }
104 
NumWorkers(int * num_workers)105 Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
106   GetWorkersRequest req;
107   GetWorkersResponse resp;
108   ::grpc::ServerContext ctx;
109   ::grpc::Status s = service_->GetWorkers(&ctx, &req, &resp);
110   if (!s.ok()) {
111     return grpc_util::WrapError("Failed to get workers", s);
112   }
113   *num_workers = resp.workers_size();
114   return Status::OK();
115 }
116 
WorkerGrpcDataServer(const experimental::WorkerConfig & config)117 WorkerGrpcDataServer::WorkerGrpcDataServer(
118     const experimental::WorkerConfig& config)
119     : GrpcDataServerBase(config.port(), config.protocol(), "WorkerServer"),
120       config_(config) {}
121 
~WorkerGrpcDataServer()122 WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
123 
AddDataServiceToBuilder(::grpc::ServerBuilder & builder)124 void WorkerGrpcDataServer::AddDataServiceToBuilder(
125     ::grpc::ServerBuilder& builder) {
126   service_ = absl::make_unique<GrpcWorkerImpl>(config_, builder).release();
127 }
128 
StartServiceInternal()129 Status WorkerGrpcDataServer::StartServiceInternal() {
130   std::string base_address = config_.worker_address();
131   if (base_address.empty()) {
132     base_address = absl::StrCat("localhost:", kPortPlaceholder);
133   }
134   std::string worker_address = str_util::StringReplace(
135       base_address, kPortPlaceholder, absl::StrCat(bound_port()),
136       /*replace_all=*/false);
137   std::string transfer_address = worker_address;
138   std::string transfer_protocol = config_.data_transfer_protocol();
139   if (!transfer_protocol.empty()) {
140     TF_RETURN_IF_ERROR(DataTransferServer::Build(
141         transfer_protocol, service_->get_element_getter(), &transfer_server_));
142     TF_RETURN_IF_ERROR(transfer_server_->Start());
143     LOG(INFO) << "Data transfer server started at 0.0.0.0:"
144               << transfer_server_->get_port();
145     transfer_address =
146         str_util::StringReplace(base_address, kPortPlaceholder,
147                                 absl::StrCat(transfer_server_->get_port()),
148                                 /*replace_all=*/false);
149   }
150   TF_RETURN_IF_ERROR(service_->Start(worker_address, transfer_address));
151   return Status::OK();
152 }
153 
NumTasks(int * num_tasks)154 Status WorkerGrpcDataServer::NumTasks(int* num_tasks) {
155   GetWorkerTasksRequest req;
156   GetWorkerTasksResponse resp;
157   ::grpc::ServerContext ctx;
158   ::grpc::Status s = service_->GetWorkerTasks(&ctx, &req, &resp);
159   if (!s.ok()) {
160     return grpc_util::WrapError("Failed to get tasks", s);
161   }
162   *num_tasks = resp.tasks_size();
163   return Status::OK();
164 }
165 
NewDispatchServer(const experimental::DispatcherConfig & config,std::unique_ptr<DispatchGrpcDataServer> & out_server)166 Status NewDispatchServer(const experimental::DispatcherConfig& config,
167                          std::unique_ptr<DispatchGrpcDataServer>& out_server) {
168   out_server = absl::make_unique<DispatchGrpcDataServer>(config);
169   return Status::OK();
170 }
171 
NewWorkerServer(const experimental::WorkerConfig & config,std::unique_ptr<WorkerGrpcDataServer> & out_server)172 Status NewWorkerServer(const experimental::WorkerConfig& config,
173                        std::unique_ptr<WorkerGrpcDataServer>& out_server) {
174   out_server = absl::make_unique<WorkerGrpcDataServer>(config);
175   return Status::OK();
176 }
177 
178 }  // namespace data
179 }  // namespace tensorflow
180