/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/service/data_service.h"

#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace data {

namespace {
constexpr const char kParallelEpochs[] = "parallel_epochs";
constexpr const char kDistributedEpoch[] = "distributed_epoch";

}  // namespace

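// Parses the string representation of a processing mode and stores the result
// in `mode`. Returns an InvalidArgument error if the string is not recognized.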
Status ParseProcessingMode(const std::string& s, ProcessingMode& mode) {
  if (s == kParallelEpochs) {
    mode = ProcessingMode::PARALLEL_EPOCHS;
  } else if (s == kDistributedEpoch) {
    mode = ProcessingMode::DISTRIBUTED_EPOCH;
  } else {
    return errors::InvalidArgument("Unrecognized processing mode: ", s);
  }
  return Status::OK();
}

std::string ProcessingModeToString(ProcessingMode mode) {
  switch (mode) {
    case ProcessingMode::PARALLEL_EPOCHS:
      return kParallelEpochs;
    case ProcessingMode::DISTRIBUTED_EPOCH:
      return kDistributedEpoch;
    default:
      DCHECK(false);
      return "Unknown";
  }
}

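// Sends a heartbeat for the worker at `worker_address`, reporting the tasks it
// is currently running. On success, `new_tasks` is filled with tasks the
// worker should start and `tasks_to_delete` with ids of tasks it should drop.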
Status DataServiceDispatcherClient::WorkerHeartbeat(
    const std::string& worker_address, const std::string& transfer_address,
    const std::vector<int64>& current_tasks, std::vector<TaskDef>& new_tasks,
    std::vector<int64>& tasks_to_delete) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  WorkerHeartbeatRequest req;
  req.set_worker_address(worker_address);
  req.set_transfer_address(transfer_address);
  for (int64 task : current_tasks) {
    req.add_current_tasks(task);
  }
  WorkerHeartbeatResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->WorkerHeartbeat(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError("Failed to perform worker heartbeat", status);
  }
  for (const auto& task : resp.new_tasks()) {
    new_tasks.push_back(task);
  }
  for (int64 task_to_delete : resp.tasks_to_delete()) {
    tasks_to_delete.push_back(task_to_delete);
  }
  return Status::OK();
}

Status DataServiceDispatcherClient::WorkerUpdate(
    const std::string& worker_address,
    std::vector<TaskProgress>& task_progress) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  WorkerUpdateRequest req;
  req.set_worker_address(worker_address);
  for (const auto& update : task_progress) {
    *(req.add_updates()) = update;
  }
  WorkerUpdateResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->WorkerUpdate(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError("Failed to send worker update", status);
  }
  return Status::OK();
}

Status DataServiceDispatcherClient::GetDatasetDef(int64 dataset_id,
                                                  DatasetDef& dataset_def) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  GetDatasetDefRequest req;
  req.set_dataset_id(dataset_id);
  GetDatasetDefResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->GetDatasetDef(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError("Failed to get dataset def", status);
  }
  dataset_def = resp.dataset_def();
  return Status::OK();
}

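// Requests the next split for repetition `repetition` of the job with id
// `job_id`. If the job has no more splits, sets `end_of_splits` to true;
// otherwise stores the split in `split`.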
Status DataServiceDispatcherClient::GetSplit(int64 job_id, int64 repetition,
                                             Tensor& split,
                                             bool& end_of_splits) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  GetSplitRequest req;
  req.set_job_id(job_id);
  req.set_repetition(repetition);
  GetSplitResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->GetSplit(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError("Failed to get split", status);
  }
  end_of_splits = resp.end_of_splits();
  if (!end_of_splits) {
    if (!split.FromProto(resp.split())) {
      return errors::Internal("Failed to parse split tensor proto");
    }
  }
  return Status::OK();
}

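// Registers the dataset graph with the dispatcher via the GetOrRegisterDataset
// RPC and stores the id assigned by the dispatcher in `dataset_id`.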
Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
                                                    int64& dataset_id) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  GetOrRegisterDatasetRequest req;
  *req.mutable_dataset()->mutable_graph() = dataset;
  GetOrRegisterDatasetResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->GetOrRegisterDataset(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError("Failed to register dataset", status);
  }
  dataset_id = resp.dataset_id();
  return Status::OK();
}

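// Gets or creates a job for the dataset with id `dataset_id`, processed
// according to `processing_mode`. If `job_key` is set, it is included in the
// request so the dispatcher can return an existing job for the same key;
// `num_consumers` is forwarded when set. On success, stores the job client id
// in `job_client_id`.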
Status DataServiceDispatcherClient::GetOrCreateJob(
    int64 dataset_id, ProcessingMode processing_mode,
    const absl::optional<JobKey>& job_key, absl::optional<int64> num_consumers,
    int64& job_client_id) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  GetOrCreateJobRequest req;
  req.set_dataset_id(dataset_id);
  req.set_processing_mode(ProcessingModeDef(processing_mode));
  if (job_key.has_value()) {
    *req.mutable_job_key() = job_key.value();
  }
  if (num_consumers.has_value()) {
    req.set_num_consumers(num_consumers.value());
  }
  GetOrCreateJobResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError(
        absl::StrCat("Failed to get or create job for dataset with id ",
                     dataset_id),
        status);
  }
  job_client_id = resp.job_client_id();
  return Status::OK();
}

Status DataServiceDispatcherClient::ReleaseJobClient(int64 job_client_id) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  ReleaseJobClientRequest req;
  req.set_job_client_id(job_client_id);
  ReleaseJobClientResponse resp;
  grpc::ClientContext client_ctx;
  grpc::Status status = stub_->ReleaseJobClient(&client_ctx, req, &resp);
  if (!status.ok()) {
    return grpc_util::WrapError(
        absl::StrCat("Failed to release job client with id ", job_client_id),
        status);
  }
  return Status::OK();
}

Status DataServiceDispatcherClient::ClientHeartbeat(
    ClientHeartbeatRequest& req, ClientHeartbeatResponse& resp) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  grpc::ClientContext ctx;
  grpc::Status s = stub_->ClientHeartbeat(&ctx, req, &resp);
  if (!s.ok()) {
    return grpc_util::WrapError("Failed to get tasks", s);
  }
  return Status::OK();
}

Status DataServiceDispatcherClient::GetWorkers(
    std::vector<WorkerInfo>& workers) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  GetWorkersRequest req;
  GetWorkersResponse resp;
  grpc::ClientContext ctx;
  grpc::Status s = stub_->GetWorkers(&ctx, req, &resp);
  if (!s.ok()) {
    return grpc_util::WrapError("Failed to get workers", s);
  }
  workers.clear();
  for (auto& worker : resp.workers()) {
    workers.push_back(worker);
  }
  return Status::OK();
}

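// Lazily creates the dispatcher gRPC stub the first time it is needed, then
// verifies that the dispatcher is running the same data service version as the
// client, retrying the version check until it succeeds.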
Status DataServiceDispatcherClient::EnsureInitialized() {
  mutex_lock l(mu_);
  if (stub_) {
    return Status::OK();
  }
  std::shared_ptr<grpc::ChannelCredentials> credentials;
  TF_RETURN_IF_ERROR(
      CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
  grpc::ChannelArguments args;
  args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
  args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
  auto channel = grpc::CreateCustomChannel(address_, credentials, args);
  stub_ = DispatcherService::NewStub(channel);
  GetVersionRequest req;
  GetVersionResponse resp;
  TF_RETURN_IF_ERROR(grpc_util::Retry(
      [&] {
        grpc::ClientContext ctx;
        grpc::Status s = stub_->GetVersion(&ctx, req, &resp);
        if (!s.ok()) {
          return grpc_util::WrapError("Failed to get dispatcher version", s);
        }
        return Status::OK();
      },
      "checking service version",
      /*deadline_micros=*/kint64max));
  if (resp.version() != kDataServiceVersion) {
    return errors::FailedPrecondition(
        "Version mismatch with tf.data service server. The server is running "
        "version ",
        resp.version(), ", while the client is running version ",
        kDataServiceVersion,
        ". Please ensure that the client and server side are running the "
        "same version of TensorFlow.");
  }
  return Status::OK();
}

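// gRPC implementation of DataTransferClient. Fetches dataset elements from a
// tf.data service worker over the WorkerService stub and supports cancelling
// in-flight requests via TryCancel().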
class GrpcDataTransferClient : public DataTransferClient {
 public:
  GrpcDataTransferClient(std::shared_ptr<grpc::ChannelCredentials> credentials,
                         std::string address) {
    grpc::ChannelArguments args;
    args.SetMaxReceiveMessageSize(-1);
    auto channel = grpc::CreateCustomChannel(address, credentials, args);
    stub_ = WorkerService::NewStub(channel);
  }

  Status GetElement(const GetElementRequest& req,
                    GetElementResponse& resp) override {
    {
      mutex_lock l(mu_);
      if (cancelled_) {
        return errors::Cancelled("Client was cancelled.");
      }
    }
    grpc::ClientContext ctx;
    {
      mutex_lock l(mu_);
      active_contexts_.insert(&ctx);
    }
    grpc::Status s = stub_->GetElement(&ctx, req, &resp);
    {
      mutex_lock l(mu_);
      active_contexts_.erase(&ctx);
    }
    if (!s.ok()) {
      return grpc_util::WrapError("Failed to get element", s);
    }
    return Status::OK();
  }

  void TryCancel() override {
    mutex_lock l(mu_);
    cancelled_ = true;
    for (const auto& ctx : active_contexts_) {
      ctx->TryCancel();
    }
  }

 private:
  mutex mu_;
  std::unique_ptr<WorkerService::Stub> stub_;
  // Set of all currently active client contexts. Used to support
  // cancellation.
  absl::flat_hash_set<::grpc::ClientContext*> active_contexts_
      TF_GUARDED_BY(mu_);
  // Indicates that the client has been cancelled, so no further requests
  // should be accepted.
  bool cancelled_ TF_GUARDED_BY(mu_) = false;
};

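// Registers the "grpc" data transfer protocol at static initialization time so
// that DataTransferClient::Build can construct GrpcDataTransferClient
// instances for that protocol.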
class GrpcTransferClientRegistrar {
 public:
  GrpcTransferClientRegistrar() {
    DataTransferClient::Register(
        "grpc", [](DataTransferClient::Config config,
                   std::unique_ptr<DataTransferClient>* out) {
          std::shared_ptr<grpc::ChannelCredentials> credentials;
          TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials(
              config.protocol, &credentials));
          *out = std::make_unique<GrpcDataTransferClient>(credentials,
                                                          config.address);
          return Status::OK();
        });
  }
};
static GrpcTransferClientRegistrar registrar;

Status DataServiceWorkerClient::GetElement(const GetElementRequest& req,
                                           GetElementResponse& resp) {
  TF_RETURN_IF_ERROR(EnsureInitialized());
  return client_->GetElement(req, resp);
}

Status DataServiceWorkerClient::EnsureInitialized() {
  mutex_lock l(mu_);
  if (client_) {
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(DataTransferClient::Build(
      transfer_protocol_, {protocol_, address_}, &client_));
  return Status::OK();
}

void DataServiceWorkerClient::TryCancel() { client_->TryCancel(); }

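// Creates and initializes a DataServiceDispatcherClient for the dispatcher at
// `address`, speaking `protocol`, and stores it in `out`.
//
// Example usage (a sketch; the address and protocol values below are
// illustrative only):
//
//   std::unique_ptr<DataServiceDispatcherClient> dispatcher;
//   TF_RETURN_IF_ERROR(
//       CreateDataServiceDispatcherClient("localhost:5000", "grpc",
//                                         dispatcher));
//   std::vector<WorkerInfo> workers;
//   TF_RETURN_IF_ERROR(dispatcher->GetWorkers(workers));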
Status CreateDataServiceDispatcherClient(
    const std::string& address, const std::string& protocol,
    std::unique_ptr<DataServiceDispatcherClient>& out) {
  auto client =
      absl::make_unique<DataServiceDispatcherClient>(address, protocol);
  TF_RETURN_IF_ERROR(client->Initialize());
  out = std::move(client);
  return Status::OK();
}

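// Creates and initializes a DataServiceWorkerClient for the worker at
// `address`, using `protocol` for credentials and `transfer_protocol` (e.g.
// "grpc") to select the DataTransferClient implementation, and stores it in
// `out`.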
Status CreateDataServiceWorkerClient(
    const std::string& address, const std::string& protocol,
    const std::string& transfer_protocol,
    std::unique_ptr<DataServiceWorkerClient>& out) {
  auto client = absl::make_unique<DataServiceWorkerClient>(address, protocol,
                                                           transfer_protocol);
  TF_RETURN_IF_ERROR(client->Initialize());
  out = std::move(client);
  return Status::OK();
}

}  // namespace data
}  // namespace tensorflow