/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/service/data_service.h"

#include <limits>
#include <memory>
#include <utility>

#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace data {

namespace {
constexpr const char kParallelEpochs[] = "parallel_epochs";
constexpr const char kDistributedEpoch[] = "distributed_epoch";

}  // namespace

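// Parses a processing mode string (e.g. "parallel_epochs") into its
// ProcessingMode enum value, returning InvalidArgument for unknown strings.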
Status ParseProcessingMode(const std::string& s, ProcessingMode& mode) {
  if (s == kParallelEpochs) {
    mode = ProcessingMode::PARALLEL_EPOCHS;
  } else if (s == kDistributedEpoch) {
    mode = ProcessingMode::DISTRIBUTED_EPOCH;
  } else {
    return errors::InvalidArgument("Unrecognized processing mode: ", s);
  }
  return Status::OK();
}

std::string ProcessingModeToString(ProcessingMode mode) {
  switch (mode) {
    case ProcessingMode::PARALLEL_EPOCHS:
      return kParallelEpochs;
    case ProcessingMode::DISTRIBUTED_EPOCH:
      return kDistributedEpoch;
    default:
      DCHECK(false);
      return "Unknown";
  }
}

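// Sends a worker heartbeat to the dispatcher, reporting the worker's current
// tasks and collecting new tasks and tasks to delete from the response.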
Status DataServiceDispatcherClient::WorkerHeartbeat(
    const std::string& worker_address, const std::string& transfer_address,
    const std::vector<int64>& current_tasks, std::vector<TaskDef>& new_tasks,
    std::vector<int64>& tasks_to_delete) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  WorkerHeartbeatRequest req;
  req.set_worker_address(worker_address);
  req.set_transfer_address(transfer_address);
  for (int64 task : current_tasks) {
    req.add_current_tasks(task);
  }
  WorkerHeartbeatResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->WorkerHeartbeat(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError("Failed to perform worker heartbeat", status);
  }
  for (const auto& task : resp.new_tasks()) {
    new_tasks.push_back(task);
  }
  for (int64 task_to_delete : resp.tasks_to_delete()) {
    tasks_to_delete.push_back(task_to_delete);
  }
  return Status::OK();
}

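// Reports per-task progress from a worker to the dispatcher.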
Status DataServiceDispatcherClient::WorkerUpdate(
    const std::string& worker_address,
    std::vector<TaskProgress>& task_progress) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  WorkerUpdateRequest req;
  req.set_worker_address(worker_address);
  for (const auto& update : task_progress) {
    *(req.add_updates()) = update;
  }
  WorkerUpdateResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->WorkerUpdate(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError("Failed to send worker update", status);
  }
  return Status::OK();
}

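// Fetches the definition of a registered dataset from the dispatcher.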
Status DataServiceDispatcherClient::GetDatasetDef(int64 dataset_id,
                                                  DatasetDef& dataset_def) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  GetDatasetDefRequest req;
  req.set_dataset_id(dataset_id);
  GetDatasetDefResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->GetDatasetDef(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError("Failed to get dataset def", status);
  }
  dataset_def = resp.dataset_def();
  return Status::OK();
}

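// Requests the next split for `job_id` from the dispatcher. `end_of_splits`
// is set from the response; when it is false, the returned split proto is
// parsed into `split`.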
Status DataServiceDispatcherClient::GetSplit(int64 job_id, int64 repetition,
                                             Tensor& split,
                                             bool& end_of_splits) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  GetSplitRequest req;
  req.set_job_id(job_id);
  req.set_repetition(repetition);
  GetSplitResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->GetSplit(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError("Failed to get split", status);
  }
  end_of_splits = resp.end_of_splits();
  if (!end_of_splits) {
    if (!split.FromProto(resp.split())) {
      return errors::Internal("Failed to parse split tensor proto");
    }
  }
  return Status::OK();
}

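// Registers a dataset graph with the dispatcher via GetOrRegisterDataset and
// returns the dataset id assigned by the dispatcher.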
Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
                                                    int64& dataset_id) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  GetOrRegisterDatasetRequest req;
  *req.mutable_dataset()->mutable_graph() = dataset;
  GetOrRegisterDatasetResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->GetOrRegisterDataset(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError("Failed to register dataset", status);
  }
  dataset_id = resp.dataset_id();
  return Status::OK();
}

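// Gets or creates a job for the given dataset and processing mode, returning
// an id for the new job client. `job_key` and `num_consumers` are only set on
// the request when provided.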
Status DataServiceDispatcherClient::GetOrCreateJob(
    int64 dataset_id, ProcessingMode processing_mode,
    const absl::optional<JobKey>& job_key, absl::optional<int64> num_consumers,
    int64& job_client_id) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  GetOrCreateJobRequest req;
  req.set_dataset_id(dataset_id);
  req.set_processing_mode(ProcessingModeDef(processing_mode));
  if (job_key.has_value()) {
    *req.mutable_job_key() = job_key.value();
  }
  if (num_consumers.has_value()) {
    req.set_num_consumers(num_consumers.value());
  }
  GetOrCreateJobResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError(
        absl::StrCat("Failed to get or create job for dataset with id ",
                     dataset_id),
        status);
  }
  job_client_id = resp.job_client_id();
  return Status::OK();
}

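// Informs the dispatcher that the given job client is no longer in use.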
Status DataServiceDispatcherClient::ReleaseJobClient(int64 job_client_id) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  ReleaseJobClientRequest req;
  req.set_job_client_id(job_client_id);
  ReleaseJobClientResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->ReleaseJobClient(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError(
        absl::StrCat("Failed to release job client with id ", job_client_id),
        status);
  }
  return Status::OK();
}

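// Forwards a client heartbeat to the dispatcher; the caller populates the
// request and reads task information from the response.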
Status DataServiceDispatcherClient::ClientHeartbeat(
    ClientHeartbeatRequest& req, ClientHeartbeatResponse& resp) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  grpc::ClientContext ctx;
  grpc::Status s = stub_->ClientHeartbeat(&ctx, req, &resp);
  if (!s.ok()) {
    return grpc_util::WrapError("Failed to get tasks", s);
  }
  return Status::OK();
}

Status DataServiceDispatcherClient::GetWorkers(
    std::vector<WorkerInfo>& workers) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  GetWorkersRequest req;
  GetWorkersResponse resp;
  grpc::ClientContext ctx;
  grpc::Status s = stub_->GetWorkers(&ctx, req, &resp);
  if (!s.ok()) {
    return grpc_util::WrapError("Failed to get workers", s);
  }
  workers.clear();
  for (auto& worker : resp.workers()) {
    workers.push_back(worker);
  }
  return Status::OK();
}

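// Lazily creates the dispatcher stub on first use and verifies that the
// server reports a data service version matching kDataServiceVersion.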
Status DataServiceDispatcherClient::EnsureInitialized() {
  mutex_lock l(mu_);
  if (stub_) {
    return Status::OK();
  }
  std::shared_ptr<grpc::ChannelCredentials> credentials;
  TF_RETURN_IF_ERROR(
      CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
  grpc::ChannelArguments args;
  args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
  args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
  auto channel = grpc::CreateCustomChannel(address_, credentials, args);
  stub_ = DispatcherService::NewStub(channel);
  GetVersionRequest req;
  GetVersionResponse resp;
  TF_RETURN_IF_ERROR(grpc_util::Retry(
      [&] {
        grpc::ClientContext ctx;
        grpc::Status s = stub_->GetVersion(&ctx, req, &resp);
        if (!s.ok()) {
          return grpc_util::WrapError("Failed to get dispatcher version", s);
        }
        return Status::OK();
      },
      "checking service version",
      /*deadline_micros=*/kint64max));
  if (resp.version() != kDataServiceVersion) {
    return errors::FailedPrecondition(
        "Version mismatch with tf.data service server. The server is running "
        "version ",
        resp.version(), ", while the client is running version ",
        kDataServiceVersion,
        ". Please ensure that the client and server side are running the "
        "same version of TensorFlow.");
  }
  return Status::OK();
}

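// DataTransferClient that fetches dataset elements from a worker over gRPC.
// TryCancel() cancels all in-flight GetElement calls and rejects new ones.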
class GrpcDataTransferClient : public DataTransferClient {
 public:
  GrpcDataTransferClient(std::shared_ptr<grpc::ChannelCredentials> credentials,
                         std::string address) {
    grpc::ChannelArguments args;
    args.SetMaxReceiveMessageSize(-1);
    auto channel = grpc::CreateCustomChannel(address, credentials, args);
    stub_ = WorkerService::NewStub(channel);
  }

  Status GetElement(const GetElementRequest& req,
                    GetElementResponse& resp) override {
    {
      mutex_lock l(mu_);
      if (cancelled_) {
        return errors::Cancelled("Client was cancelled.");
      }
    }
    grpc::ClientContext ctx;
    {
      mutex_lock l(mu_);
      active_contexts_.insert(&ctx);
    }
    grpc::Status s = stub_->GetElement(&ctx, req, &resp);
    {
      mutex_lock l(mu_);
      active_contexts_.erase(&ctx);
    }
    if (!s.ok()) {
      return grpc_util::WrapError("Failed to get element", s);
    }
    return Status::OK();
  }

  void TryCancel() override {
    mutex_lock l(mu_);
    cancelled_ = true;
    for (const auto& ctx : active_contexts_) {
      ctx->TryCancel();
    }
  }

 private:
  mutex mu_;
  std::unique_ptr<WorkerService::Stub> stub_;
  // Set of all currently active client contexts. Used to support
  // cancellation.
  absl::flat_hash_set<::grpc::ClientContext*> active_contexts_
      TF_GUARDED_BY(mu_);
  // Indicates that the client has been cancelled, so no further requests
  // should be accepted.
  bool cancelled_ TF_GUARDED_BY(mu_) = false;
};

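// Registers the gRPC data transfer client under the "grpc" protocol name at
// static initialization time.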
class GrpcTransferClientRegistrar {
 public:
  GrpcTransferClientRegistrar() {
    DataTransferClient::Register(
        "grpc", [](DataTransferClient::Config config,
                   std::unique_ptr<DataTransferClient>* out) {
          std::shared_ptr<grpc::ChannelCredentials> credentials;
          TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials(
              config.protocol, &credentials));
          *out = std::make_unique<GrpcDataTransferClient>(credentials,
                                                          config.address);
          return Status::OK();
        });
  }
};
static GrpcTransferClientRegistrar registrar;

Status DataServiceWorkerClient::GetElement(const GetElementRequest& req,
                                           GetElementResponse& resp) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  return client_->GetElement(req, resp);
}

Status DataServiceWorkerClient::EnsureInitialized() {
  mutex_lock l(mu_);
  if (client_) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(DataTransferClient::Build(
      transfer_protocol_, {protocol_, address_}, &client_));
  return Status::OK();
}

void DataServiceWorkerClient::TryCancel() { client_->TryCancel(); }

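// Convenience factories that construct and initialize a client in one step.
// As an illustrative sketch (the address and protocol values below are just
// examples), a caller might do:
//
//   std::unique_ptr<DataServiceDispatcherClient> dispatcher;
//   TF_RETURN_IF_ERROR(CreateDataServiceDispatcherClient(
//       /*address=*/"localhost:5000", /*protocol=*/"grpc", dispatcher));
//   std::vector<WorkerInfo> workers;
//   TF_RETURN_IF_ERROR(dispatcher->GetWorkers(workers));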
Status CreateDataServiceDispatcherClient(
    const std::string& address, const std::string& protocol,
    std::unique_ptr<DataServiceDispatcherClient>& out) {
  auto client =
      absl::make_unique<DataServiceDispatcherClient>(address, protocol);
  TF_RETURN_IF_ERROR(client->Initialize());
  out = std::move(client);
  return Status::OK();
}

Status CreateDataServiceWorkerClient(
    const std::string& address, const std::string& protocol,
    const std::string& transfer_protocol,
    std::unique_ptr<DataServiceWorkerClient>& out) {
  auto client = absl::make_unique<DataServiceWorkerClient>(address, protocol,
                                                           transfer_protocol);
  TF_RETURN_IF_ERROR(client->Initialize());
  out = std::move(client);
  return Status::OK();
}

}  // namespace data
}  // namespace tensorflow