1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/service/server_lib.h"
17
18 #include "tensorflow/core/data/service/credentials_factory.h"
19 #include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
20 #include "tensorflow/core/data/service/grpc_util.h"
21 #include "tensorflow/core/data/service/grpc_worker_impl.h"
22 #include "tensorflow/core/platform/errors.h"
23
24 namespace tensorflow {
25 namespace data {
26
namespace {
// Token that may appear in a configured worker address. It is substituted
// with a concrete port number in WorkerGrpcDataServer::StartServiceInternal().
constexpr char kPortPlaceholder[] = "%port%";
}
30
// Records the configuration for a data service server without starting it.
//
// `port` is the port requested for listening, `protocol` is later passed to
// CredentialsFactory::CreateServerCredentials() in Start(), and `server_type`
// is a label used in log messages. `bound_port_` starts out equal to the
// requested port; Start() hands its address to the gRPC builder as the
// selected-port out-parameter.
GrpcDataServerBase::GrpcDataServerBase(int port, const std::string& protocol,
                                       const std::string server_type)
    : requested_port_(port),
      protocol_(protocol),
      server_type_(server_type),
      bound_port_(port) {}
37
Start()38 Status GrpcDataServerBase::Start() {
39 if (stopped_) {
40 return errors::FailedPrecondition(
41 "Server cannot be started after it has been stopped.");
42 }
43 if (started_) {
44 return Status::OK();
45 }
46 ::grpc::ServerBuilder builder;
47 std::shared_ptr<::grpc::ServerCredentials> credentials;
48 TF_RETURN_IF_ERROR(
49 CredentialsFactory::CreateServerCredentials(protocol_, &credentials));
50 builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_),
51 credentials, &bound_port_);
52 builder.SetMaxReceiveMessageSize(-1);
53
54 AddDataServiceToBuilder(builder);
55 AddProfilerServiceToBuilder(builder);
56 server_ = builder.BuildAndStart();
57 if (!server_) {
58 return errors::Internal("Could not start gRPC server");
59 }
60
61 TF_RETURN_IF_ERROR(StartServiceInternal());
62
63 started_ = true;
64 LOG(INFO) << "Started tf.data " << server_type_
65 << " running at 0.0.0.0:" << BoundPort();
66 return Status::OK();
67 }
68
Stop()69 void GrpcDataServerBase::Stop() {
70 if (stopped_) {
71 return;
72 }
73 server_->Shutdown();
74 stopped_ = true;
75 LOG(INFO) << "Shut down " << server_type_ << " server running at port "
76 << BoundPort();
77 }
78
Join()79 void GrpcDataServerBase::Join() { server_->Wait(); }
80
// Returns the value reported by bound_port().
int GrpcDataServerBase::BoundPort() { return bound_port(); }
82
// Creates the profiler service and registers it with `builder`. The service
// object is kept in `profiler_service_`; only a raw pointer to it is given to
// the builder.
void GrpcDataServerBase::AddProfilerServiceToBuilder(
    ::grpc::ServerBuilder& builder) {
  profiler_service_ = profiler::CreateProfilerService();
  builder.RegisterService(profiler_service_.get());
}
88
// Constructs a dispatch server from `config`: the base class receives the
// config's port and protocol together with the log label "DispatchServer",
// and `config` itself is stored in `config_`. The server is not started here.
DispatchGrpcDataServer::DispatchGrpcDataServer(
    const experimental::DispatcherConfig& config)
    : GrpcDataServerBase(config.port(), config.protocol(), "DispatchServer"),
      config_(config) {}
93
// Frees the dispatcher service implementation held through the raw
// `service_` pointer (assigned in AddDataServiceToBuilder()).
DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
95
AddDataServiceToBuilder(::grpc::ServerBuilder & builder)96 void DispatchGrpcDataServer::AddDataServiceToBuilder(
97 ::grpc::ServerBuilder& builder) {
98 service_ = absl::make_unique<GrpcDispatcherImpl>(config_, builder).release();
99 }
100
// Starts the dispatcher service implementation and returns its status.
Status DispatchGrpcDataServer::StartServiceInternal() {
  return service_->Start();
}
104
NumWorkers(int * num_workers)105 Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
106 GetWorkersRequest req;
107 GetWorkersResponse resp;
108 ::grpc::ServerContext ctx;
109 ::grpc::Status s = service_->GetWorkers(&ctx, &req, &resp);
110 if (!s.ok()) {
111 return grpc_util::WrapError("Failed to get workers", s);
112 }
113 *num_workers = resp.workers_size();
114 return Status::OK();
115 }
116
// Constructs a worker server from `config`: the base class receives the
// config's port and protocol together with the log label "WorkerServer", and
// `config` itself is stored in `config_`. The server is not started here.
WorkerGrpcDataServer::WorkerGrpcDataServer(
    const experimental::WorkerConfig& config)
    : GrpcDataServerBase(config.port(), config.protocol(), "WorkerServer"),
      config_(config) {}
121
// Frees the worker service implementation held through the raw `service_`
// pointer (assigned in AddDataServiceToBuilder()).
WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
123
AddDataServiceToBuilder(::grpc::ServerBuilder & builder)124 void WorkerGrpcDataServer::AddDataServiceToBuilder(
125 ::grpc::ServerBuilder& builder) {
126 service_ = absl::make_unique<GrpcWorkerImpl>(config_, builder).release();
127 }
128
StartServiceInternal()129 Status WorkerGrpcDataServer::StartServiceInternal() {
130 std::string base_address = config_.worker_address();
131 if (base_address.empty()) {
132 base_address = absl::StrCat("localhost:", kPortPlaceholder);
133 }
134 std::string worker_address = str_util::StringReplace(
135 base_address, kPortPlaceholder, absl::StrCat(bound_port()),
136 /*replace_all=*/false);
137 std::string transfer_address = worker_address;
138 std::string transfer_protocol = config_.data_transfer_protocol();
139 if (!transfer_protocol.empty()) {
140 TF_RETURN_IF_ERROR(DataTransferServer::Build(
141 transfer_protocol, service_->get_element_getter(), &transfer_server_));
142 TF_RETURN_IF_ERROR(transfer_server_->Start());
143 LOG(INFO) << "Data transfer server started at 0.0.0.0:"
144 << transfer_server_->get_port();
145 transfer_address =
146 str_util::StringReplace(base_address, kPortPlaceholder,
147 absl::StrCat(transfer_server_->get_port()),
148 /*replace_all=*/false);
149 }
150 TF_RETURN_IF_ERROR(service_->Start(worker_address, transfer_address));
151 return Status::OK();
152 }
153
NumTasks(int * num_tasks)154 Status WorkerGrpcDataServer::NumTasks(int* num_tasks) {
155 GetWorkerTasksRequest req;
156 GetWorkerTasksResponse resp;
157 ::grpc::ServerContext ctx;
158 ::grpc::Status s = service_->GetWorkerTasks(&ctx, &req, &resp);
159 if (!s.ok()) {
160 return grpc_util::WrapError("Failed to get tasks", s);
161 }
162 *num_tasks = resp.tasks_size();
163 return Status::OK();
164 }
165
// Creates a DispatchGrpcDataServer for `config` and stores it in
// `out_server`, replacing whatever it previously held. The server is
// constructed but not started. Always returns OK.
Status NewDispatchServer(const experimental::DispatcherConfig& config,
                         std::unique_ptr<DispatchGrpcDataServer>& out_server) {
  out_server = absl::make_unique<DispatchGrpcDataServer>(config);
  return Status::OK();
}
171
// Creates a WorkerGrpcDataServer for `config` and stores it in `out_server`,
// replacing whatever it previously held. The server is constructed but not
// started. Always returns OK.
Status NewWorkerServer(const experimental::WorkerConfig& config,
                       std::unique_ptr<WorkerGrpcDataServer>& out_server) {
  out_server = absl::make_unique<WorkerGrpcDataServer>(config);
  return Status::OK();
}
177
178 } // namespace data
179 } // namespace tensorflow
180