1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/data/service/dispatcher_impl.h"

#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#ifdef PLATFORM_GOOGLE
#include "file/logging/log_lines.h"
#endif
#include "grpcpp/create_channel.h"
#include "grpcpp/impl/codegen/server_context.h"
#include "grpcpp/security/credentials.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dataset_store.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/journal.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/hash_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
#include "tensorflow/core/public/session_options.h"
47 
48 namespace tensorflow {
49 namespace data {
50 
namespace {
// The name of the journal directory inside the dispatcher's working directory.
// This name is load-bearing; do not change.
constexpr char kJournalDir[] = "tf_data_dispatcher_journal";
// The name of the datasets directory inside the dispatcher's working directory.
constexpr char kDatasetsDir[] = "datasets";

// Ops whose resources must be shared by node name so they aren't deleted
// prematurely when their creating ops are torn down (see PrepareGraph).
constexpr std::array<const char*, 8> kNodeNameSharingOps = {
    "HashTable",
    "HashTableV2",
    "MutableHashTable",
    "MutableHashTableV2",
    "MutableDenseHashTable",
    "MutableDenseHashTableV2",
    "MutableHashTableOfTensors",
    "MutableHashTableOfTensorsV2",
};

// Short aliases for the journaled dispatcher state types.
using Dataset = DispatcherState::Dataset;
using Worker = DispatcherState::Worker;
using NamedJobKey = DispatcherState::NamedJobKey;
using Job = DispatcherState::Job;
using Task = DispatcherState::Task;
JournalDir(const std::string & work_dir)75 std::string JournalDir(const std::string& work_dir) {
76   return io::JoinPath(work_dir, kJournalDir);
77 }
78 
DatasetsDir(const std::string & work_dir)79 std::string DatasetsDir(const std::string& work_dir) {
80   return io::JoinPath(work_dir, kDatasetsDir);
81 }
82 
DatasetKey(int64 id,uint64 fingerprint)83 std::string DatasetKey(int64 id, uint64 fingerprint) {
84   return absl::StrCat("id_", id, "_fp_", fingerprint);
85 }
86 
CreateWorkerStub(const std::string & address,const std::string & protocol,std::unique_ptr<WorkerService::Stub> & stub)87 Status CreateWorkerStub(const std::string& address, const std::string& protocol,
88                         std::unique_ptr<WorkerService::Stub>& stub) {
89   ::grpc::ChannelArguments args;
90   args.SetMaxReceiveMessageSize(-1);
91   std::shared_ptr<::grpc::ChannelCredentials> credentials;
92   TF_RETURN_IF_ERROR(
93       CredentialsFactory::CreateClientCredentials(protocol, &credentials));
94   auto channel = ::grpc::CreateCustomChannel(address, credentials, args);
95   stub = WorkerService::NewStub(channel);
96   return Status::OK();
97 }
98 
PrepareGraph(GraphDef * graph)99 void PrepareGraph(GraphDef* graph) {
100   for (NodeDef& node : *graph->mutable_node()) {
101     for (const auto& op : kNodeNameSharingOps) {
102       // Set `use_node_name_sharing` to `true` so that resources aren't deleted
103       // prematurely. Otherwise, resources may be deleted when their ops are
104       // deleted at the end of the GraphRunner::Run used by standalone::Dataset.
105       if (node.op() == op) {
106         (*node.mutable_attr())["use_node_name_sharing"].set_b(true);
107       }
108       if (!node.device().empty()) {
109         *node.mutable_device() = "";
110       }
111     }
112   }
113   StripDevicePlacement(graph->mutable_library());
114 }
115 }  // namespace
116 
DataServiceDispatcherImpl(const experimental::DispatcherConfig & config)117 DataServiceDispatcherImpl::DataServiceDispatcherImpl(
118     const experimental::DispatcherConfig& config)
119     : config_(config), env_(Env::Default()) {
120   if (config_.work_dir().empty()) {
121     dataset_store_ = absl::make_unique<MemoryDatasetStore>();
122   } else {
123     dataset_store_ = absl::make_unique<FileSystemDatasetStore>(
124         DatasetsDir(config_.work_dir()));
125   }
126 }
127 
DataServiceDispatcherImpl::~DataServiceDispatcherImpl() {
  {
    // Signal cancellation under the lock, then release `mu_` before joining
    // so the GC thread can observe `cancelled_` and exit.
    mutex_lock l(mu_);
    cancelled_ = true;
    job_gc_thread_cv_.notify_all();
  }
  // Destroying the Thread object joins the job GC thread.
  job_gc_thread_.reset();
}
136 
// Starts the dispatcher: launches the job GC thread, creates the datasets
// directory when a work_dir is configured, and — in fault-tolerant mode —
// replays the journal to restore dispatcher state, including split providers
// for distributed_epoch jobs. Sets `started_`, which gates the RPC handlers
// via CheckStarted().
Status DataServiceDispatcherImpl::Start() {
  mutex_lock l(mu_);
  job_gc_thread_ = absl::WrapUnique(
      env_->StartThread({}, "job-gc-thread", [&] { JobGcThread(); }));
  if (config_.work_dir().empty()) {
    if (config_.fault_tolerant_mode()) {
      // Fault tolerance needs somewhere to write the journal.
      return errors::InvalidArgument(
          "fault_tolerant_mode is True, but no work_dir is configured.");
    }
  } else {
    TF_RETURN_IF_ERROR(
        env_->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
  }
  if (!config_.fault_tolerant_mode()) {
    LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will "
                 "not be able to recover its state on restart.";
    started_ = true;
    return Status::OK();
  }
  journal_writer_ = absl::make_unique<FileJournalWriter>(
      env_, JournalDir(config_.work_dir()));
  LOG(INFO) << "Attempting to restore dispatcher state from journal in "
            << JournalDir(config_.work_dir());
  Update update;
  bool end_of_journal = false;
  FileJournalReader reader(env_, JournalDir(config_.work_dir()));
  Status s = reader.Read(update, end_of_journal);
  if (errors::IsNotFound(s)) {
    // Expected on first startup: there is no journal yet.
    LOG(INFO) << "No journal found. Starting dispatcher from new state.";
  } else if (!s.ok()) {
    return s;
  } else {
    // Replay every journaled update without re-journaling it.
    while (!end_of_journal) {
      TF_RETURN_IF_ERROR(ApplyWithoutJournaling(update));
      TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
    }
  }
  // Split providers aren't journaled as objects; rebuild one for each
  // distributed_epoch job and fast-forward it to the recorded index.
  for (const auto& job : state_.ListJobs()) {
    if (job->processing_mode == ProcessingMode::DISTRIBUTED_EPOCH) {
      TF_RETURN_IF_ERROR(
          RestoreSplitProvider(*job, split_providers_[job->job_id]));
    }
  }
  // Initialize the journal writer in `Start` so that we fail fast in case it
  // can't be initialized.
  TF_RETURN_IF_ERROR(journal_writer_.value()->EnsureInitialized());
  started_ = true;
  return Status::OK();
}
186 
RestoreSplitProvider(const Job & job,std::unique_ptr<SplitProvider> & restored)187 Status DataServiceDispatcherImpl::RestoreSplitProvider(
188     const Job& job, std::unique_ptr<SplitProvider>& restored)
189     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
190   int64 index = job.distributed_epoch_state.value().split_provider_index;
191   VLOG(1) << "Restoring split provider for job " << job.job_id << " to index "
192           << index;
193   std::unique_ptr<SplitProvider> split_provider;
194   TF_RETURN_IF_ERROR(MakeSplitProvider(job.dataset_id, split_provider));
195   Tensor unused_tensor;
196   bool unused_end_of_splits;
197   for (int i = 0; i < index; ++i) {
198     TF_RETURN_IF_ERROR(
199         split_provider->GetNext(&unused_tensor, &unused_end_of_splits));
200   }
201   restored = std::move(split_provider);
202   return Status::OK();
203 }
204 
// Handles a periodic worker heartbeat. Registers the worker if it is new,
// returns the tasks the worker should start running (`new_tasks`) and the
// ids of tasks it reports but should no longer run (`tasks_to_delete`).
Status DataServiceDispatcherImpl::WorkerHeartbeat(
    const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  VLOG(4) << "Received worker heartbeat request from worker "
          << request->worker_address();
  mutex_lock l(mu_);
  const std::string& worker_address = request->worker_address();
  // The authoritative set of tasks this worker should be running.
  std::vector<std::shared_ptr<const Task>> correct_tasks;
  Status s = state_.TasksForWorker(worker_address, correct_tasks);
  if (!s.ok()) {
    if (!errors::IsNotFound(s)) {
      return s;
    }
    // NotFound means the worker hasn't registered yet: journal its
    // registration, create tasks for it for all active jobs, then re-query.
    VLOG(1) << "Registering new worker at address " << worker_address;
    Update update;
    update.mutable_register_worker()->set_worker_address(worker_address);
    update.mutable_register_worker()->set_transfer_address(
        request->transfer_address());
    TF_RETURN_IF_ERROR(Apply(update));
    TF_RETURN_IF_ERROR(CreateTasksForWorker(worker_address));
    TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, correct_tasks));
  }

  absl::flat_hash_set<int64> current_tasks;
  current_tasks.insert(request->current_tasks().cbegin(),
                       request->current_tasks().cend());
  absl::flat_hash_set<int64> correct_tasks_set;

  // Send the worker every task it should run but hasn't reported.
  for (const auto& task : correct_tasks) {
    correct_tasks_set.insert(task->task_id);
    if (current_tasks.contains(task->task_id)) {
      continue;
    }
    TaskDef* task_def = response->add_new_tasks();
    std::shared_ptr<const Dataset> dataset;
    TF_RETURN_IF_ERROR(state_.DatasetFromId(task->job->dataset_id, dataset));
    std::string dataset_key =
        DatasetKey(dataset->dataset_id, dataset->fingerprint);
    if (config_.work_dir().empty()) {
      // No shared work dir: inline the dataset definition in the response.
      std::shared_ptr<const DatasetDef> dataset_def;
      TF_RETURN_IF_ERROR(dataset_store_->Get(dataset_key, dataset_def));
      *task_def->mutable_dataset_def() = *dataset_def;
    } else {
      // Shared work dir: send only the path; the worker reads it itself.
      std::string path =
          io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
      task_def->set_path(path);
    }
    task_def->set_dataset_id(task->job->dataset_id);
    task_def->set_job_id(task->job->job_id);
    task_def->set_task_id(task->task_id);
    task_def->set_processing_mode(
        ProcessingModeDef(task->job->processing_mode));
    if (task->job->num_consumers.has_value()) {
      task_def->set_num_consumers(task->job->num_consumers.value());
    }
  }
  // Ask the worker to drop tasks it reports but the dispatcher doesn't
  // consider valid (e.g. tasks from finished jobs).
  for (int64 current_task : current_tasks) {
    if (!correct_tasks_set.contains(current_task)) {
      response->add_tasks_to_delete(current_task);
    }
  }

  VLOG(4) << "Finished worker heartbeat for worker at address "
          << request->worker_address();
  return Status::OK();
}
271 
WorkerUpdate(const WorkerUpdateRequest * request,WorkerUpdateResponse * response)272 Status DataServiceDispatcherImpl::WorkerUpdate(
273     const WorkerUpdateRequest* request, WorkerUpdateResponse* response) {
274   TF_RETURN_IF_ERROR(CheckStarted());
275   mutex_lock l(mu_);
276   for (auto& update : request->updates()) {
277     int64 task_id = update.task_id();
278     std::shared_ptr<const Task> task;
279     TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
280     if (update.completed()) {
281       if (task->finished) {
282         VLOG(1) << "Received completion update for already-finished task "
283                 << task->task_id << " on worker " << task->worker_address;
284         continue;
285       }
286       Update update;
287       update.mutable_finish_task()->set_task_id(task_id);
288       TF_RETURN_IF_ERROR(Apply(update));
289       VLOG(3) << "Task " << task_id << " from job " << task->job->job_id
290               << " completed";
291     }
292   }
293   return Status::OK();
294 }
295 
GetDatasetDef(const GetDatasetDefRequest * request,GetDatasetDefResponse * response)296 Status DataServiceDispatcherImpl::GetDatasetDef(
297     const GetDatasetDefRequest* request, GetDatasetDefResponse* response) {
298   TF_RETURN_IF_ERROR(CheckStarted());
299   mutex_lock l(mu_);
300   std::shared_ptr<const Dataset> dataset;
301   TF_RETURN_IF_ERROR(state_.DatasetFromId(request->dataset_id(), dataset));
302   std::shared_ptr<const DatasetDef> dataset_def;
303   TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def));
304   *response->mutable_dataset_def() = *dataset_def;
305   return Status::OK();
306 }
307 
GetSplit(const GetSplitRequest * request,GetSplitResponse * response)308 Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request,
309                                            GetSplitResponse* response) {
310   TF_RETURN_IF_ERROR(CheckStarted());
311   mutex_lock l(mu_);
312   int64 job_id = request->job_id();
313   int64 repetition = request->repetition();
314   VLOG(3) << "Received GetSplit request for job " << job_id << ", repetition "
315           << repetition;
316   std::shared_ptr<const Job> job;
317   TF_RETURN_IF_ERROR(state_.JobFromId(job_id, job));
318   if (!job->distributed_epoch_state.has_value()) {
319     return errors::FailedPrecondition(
320         "Cannot get split for job ", job_id,
321         ", since it is not a distributed_epoch job.");
322   }
323   int64 current_repetition = job->distributed_epoch_state.value().repetition;
324   if (repetition < current_repetition) {
325     response->set_end_of_splits(true);
326     VLOG(3) << "Returning end_of_splits since current reptition "
327             << current_repetition << " is greater than the requested reptition "
328             << repetition;
329     return Status::OK();
330   }
331   SplitProvider* split_provider = split_providers_[job_id].get();
332   DCHECK(split_provider != nullptr);
333   Tensor split;
334   bool end_of_splits = false;
335   TF_RETURN_IF_ERROR(split_provider->GetNext(&split, &end_of_splits));
336   TF_RETURN_IF_ERROR(RecordSplitProduced(job_id, repetition, end_of_splits));
337   response->set_end_of_splits(end_of_splits);
338   if (end_of_splits) {
339     // Create a new split provider for the next repetition.
340     TF_RETURN_IF_ERROR(
341         MakeSplitProvider(job->dataset_id, split_providers_[job_id]));
342   } else {
343     split.AsProtoTensorContent(response->mutable_split());
344   }
345   VLOG(3) << "Returning from GetSplit, end_of_splits=" << end_of_splits;
346   return Status::OK();
347 }
348 
MakeSplitProvider(int64 dataset_id,std::unique_ptr<SplitProvider> & split_provider)349 Status DataServiceDispatcherImpl::MakeSplitProvider(
350     int64 dataset_id, std::unique_ptr<SplitProvider>& split_provider)
351     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
352   std::shared_ptr<const Dataset> dataset;
353   TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
354   std::shared_ptr<const DatasetDef> dataset_def;
355   TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def));
356   standalone::Dataset::Params params;
357   std::unique_ptr<standalone::Dataset> standalone_dataset;
358   TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
359       params, dataset_def->graph(), &standalone_dataset));
360   TF_RETURN_IF_ERROR(standalone_dataset->MakeSplitProvider(&split_provider));
361   return Status::OK();
362 }
363 
// Reports the data service version so clients/workers can verify
// compatibility with this dispatcher.
Status DataServiceDispatcherImpl::GetVersion(const GetVersionRequest* request,
                                             GetVersionResponse* response) {
  response->set_version(kDataServiceVersion);
  return Status::OK();
}
369 
// Registers the dataset in the request, or returns the id of an existing
// dataset whose normalized graph has the same fingerprint. The (potentially
// expensive) graph hash is computed before taking `mu_`.
Status DataServiceDispatcherImpl::GetOrRegisterDataset(
    const GetOrRegisterDatasetRequest* request,
    GetOrRegisterDatasetResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  uint64 fingerprint;
  DatasetDef dataset_def = request->dataset();
  GraphDef* graph = dataset_def.mutable_graph();
  // Normalize the graph (node-name sharing, device stripping) before hashing
  // so equivalent datasets from different clients fingerprint identically.
  PrepareGraph(graph);
  TF_RETURN_IF_ERROR(HashGraph(*graph, &fingerprint));

  mutex_lock l(mu_);
#if defined(PLATFORM_GOOGLE)
  VLOG_LINES(4,
             absl::StrCat("Registering dataset graph: ", graph->DebugString()));
#else
  VLOG(4) << "Registering dataset graph: " << graph->DebugString();
#endif
  std::shared_ptr<const Dataset> dataset;
  Status s = state_.DatasetFromFingerprint(fingerprint, dataset);
  if (s.ok()) {
    // Deduplicate: a dataset with this fingerprint is already registered.
    int64 id = dataset->dataset_id;
    VLOG(3) << "Received duplicate RegisterDataset request with fingerprint "
            << fingerprint << ". Returning id " << id;
    response->set_dataset_id(id);
    return Status::OK();
  } else if (!errors::IsNotFound(s)) {
    return s;
  }

  // NotFound: this is a new dataset — register it.
  int64 id;
  TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, dataset_def, id));
  response->set_dataset_id(id);
  VLOG(3) << "Registered new dataset with id " << id;
  return Status::OK();
}
405 
RegisterDataset(uint64 fingerprint,const DatasetDef & dataset,int64 & dataset_id)406 Status DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint,
407                                                   const DatasetDef& dataset,
408                                                   int64& dataset_id)
409     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
410   dataset_id = state_.NextAvailableDatasetId();
411   Update update;
412   RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset();
413   register_dataset->set_dataset_id(dataset_id);
414   register_dataset->set_fingerprint(fingerprint);
415   TF_RETURN_IF_ERROR(
416       dataset_store_->Put(DatasetKey(dataset_id, fingerprint), dataset));
417   return Apply(update);
418 }
419 
// Returns an existing named job matching the request, or creates a new job
// plus its tasks. Either way a fresh job client id is acquired for the
// caller. Task assignment runs outside `mu_` because it performs RPCs to
// workers.
Status DataServiceDispatcherImpl::GetOrCreateJob(
    const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  VLOG(3) << "GetOrCreateJob(" << request->DebugString() << ")";
  absl::optional<NamedJobKey> key;
  if (request->has_job_key()) {
    key.emplace(request->job_key().job_name(),
                request->job_key().job_name_index());
  }
  ProcessingMode requested_processing_mode =
      ProcessingMode(request->processing_mode());
  std::shared_ptr<const Job> job;
  std::vector<std::shared_ptr<const Task>> tasks;
  {
    mutex_lock l(mu_);
    if (key.has_value()) {
      // Named job: try to reuse an existing job under the same key.
      Status s = state_.NamedJobByKey(key.value(), job);
      if (s.ok()) {
        TF_RETURN_IF_ERROR(ValidateMatchingJob(job, requested_processing_mode,
                                               request->dataset_id()));
        int64 job_client_id;
        TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
        response->set_job_client_id(job_client_id);
        VLOG(3) << "Found existing job for name=" << key.value().name
                << ", index=" << key.value().index
                << ". job_id: " << job->job_id;
        return Status::OK();
      } else if (!errors::IsNotFound(s)) {
        return s;
      }
      // NotFound: fall through and create the named job.
    }
    absl::optional<int64> num_consumers;
    if (request->optional_num_consumers_case() ==
        GetOrCreateJobRequest::kNumConsumers) {
      num_consumers = request->num_consumers();
    }
    TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(),
                                 requested_processing_mode, key, num_consumers,
                                 job));
    int64 job_client_id;
    TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
    response->set_job_client_id(job_client_id);
    TF_RETURN_IF_ERROR(CreateTasksForJob(job, tasks));
  }
  // Performed without holding `mu_`: assignment sends RPCs to workers.
  TF_RETURN_IF_ERROR(AssignTasks(tasks));
  VLOG(3) << "Created job " << job->job_id << " for CreateJob("
          << request->DebugString() << ")";
  return Status::OK();
}
469 
ReleaseJobClient(const ReleaseJobClientRequest * request,ReleaseJobClientResponse * response)470 Status DataServiceDispatcherImpl::ReleaseJobClient(
471     const ReleaseJobClientRequest* request,
472     ReleaseJobClientResponse* response) {
473   TF_RETURN_IF_ERROR(CheckStarted());
474   mutex_lock l(mu_);
475   int64 job_client_id = request->job_client_id();
476   std::shared_ptr<const Job> job;
477   TF_RETURN_IF_ERROR(state_.JobForJobClientId(job_client_id, job));
478   Update update;
479   ReleaseJobClientUpdate* release_job_client =
480       update.mutable_release_job_client();
481   release_job_client->set_job_client_id(job_client_id);
482   release_job_client->set_time_micros(env_->NowMicros());
483   TF_RETURN_IF_ERROR(Apply(update));
484   return Status::OK();
485 }
486 
487 // Validates that the job matches the given processing_mode and dataset_id.
ValidateMatchingJob(std::shared_ptr<const Job> job,ProcessingMode processing_mode,int64 dataset_id)488 Status DataServiceDispatcherImpl::ValidateMatchingJob(
489     std::shared_ptr<const Job> job, ProcessingMode processing_mode,
490     int64 dataset_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
491   DCHECK(job->named_job_key.has_value());
492   std::string job_name = job->named_job_key->name;
493   if (job->processing_mode != processing_mode) {
494     std::string requested = ProcessingModeToString(processing_mode);
495     std::string actual = ProcessingModeToString(job->processing_mode);
496     return errors::FailedPrecondition(
497         "Tried to create a job with name ", job_name, " and processing_mode <",
498         requested,
499         "> but there is already an existing job with that name using "
500         "processing mode <",
501         actual, ">");
502   }
503   return Status::OK();
504 }
505 
// Creates and journals a new job for `dataset_id` with the given processing
// mode and optional name/consumer-count. For distributed_epoch jobs, a
// split provider is created up front since the dispatcher hands out splits
// itself. On success, `job` points to the newly created job.
Status DataServiceDispatcherImpl::CreateJob(
    int64 dataset_id, ProcessingMode processing_mode,
    absl::optional<NamedJobKey> named_job_key,
    absl::optional<int64> num_consumers, std::shared_ptr<const Job>& job)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  // Only the two known modes are supported; anything else indicates a caller
  // bug or a client/server version mismatch.
  switch (processing_mode) {
    case ProcessingMode::PARALLEL_EPOCHS:
    case ProcessingMode::DISTRIBUTED_EPOCH:
      break;
    default:
      return errors::Internal(
          absl::StrCat("ProcessingMode ", processing_mode, " not recognized"));
  }
  int64 job_id = state_.NextAvailableJobId();
  if (processing_mode == ProcessingMode::DISTRIBUTED_EPOCH) {
    TF_RETURN_IF_ERROR(MakeSplitProvider(dataset_id, split_providers_[job_id]));
  }
  Update update;
  CreateJobUpdate* create_job = update.mutable_create_job();
  create_job->set_job_id(job_id);
  create_job->set_dataset_id(dataset_id);
  create_job->set_processing_mode(ProcessingModeDef(processing_mode));
  if (named_job_key.has_value()) {
    NamedJobKeyDef* key = create_job->mutable_named_job_key();
    key->set_name(named_job_key->name);
    key->set_index(named_job_key->index);
  }
  if (num_consumers.has_value()) {
    create_job->set_num_consumers(num_consumers.value());
  }
  TF_RETURN_IF_ERROR(Apply(update));
  // Re-read the job from journaled state so `job` reflects the applied update.
  TF_RETURN_IF_ERROR(state_.JobFromId(job_id, job));
  return Status::OK();
}
540 
CreateTasksForWorker(const std::string & worker_address)541 Status DataServiceDispatcherImpl::CreateTasksForWorker(
542     const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
543   std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
544   for (const auto& job : jobs) {
545     if (job->finished) {
546       continue;
547     }
548     if (job->num_consumers.has_value()) {
549       TF_RETURN_IF_ERROR(CreatePendingTask(job, worker_address));
550       continue;
551     }
552     std::shared_ptr<const Task> task;
553     TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
554   }
555   return Status::OK();
556 }
557 
AcquireJobClientId(const std::shared_ptr<const Job> & job,int64 & job_client_id)558 Status DataServiceDispatcherImpl::AcquireJobClientId(
559     const std::shared_ptr<const Job>& job, int64& job_client_id)
560     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
561   job_client_id = state_.NextAvailableJobClientId();
562   Update update;
563   AcquireJobClientUpdate* acquire_job_client =
564       update.mutable_acquire_job_client();
565   acquire_job_client->set_job_client_id(job_client_id);
566   acquire_job_client->set_job_id(job->job_id);
567   TF_RETURN_IF_ERROR(Apply(update));
568   return Status::OK();
569 }
570 
CreateTasksForJob(std::shared_ptr<const Job> job,std::vector<std::shared_ptr<const Task>> & tasks)571 Status DataServiceDispatcherImpl::CreateTasksForJob(
572     std::shared_ptr<const Job> job,
573     std::vector<std::shared_ptr<const Task>>& tasks)
574     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
575   std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
576   tasks.clear();
577   tasks.reserve(workers.size());
578   for (const auto& worker : workers) {
579     std::shared_ptr<const Task> task;
580     TF_RETURN_IF_ERROR(CreateTask(job, worker->address, task));
581     tasks.push_back(task);
582   }
583   return Status::OK();
584 }
585 
CreatePendingTask(std::shared_ptr<const Job> job,const std::string & worker_address)586 Status DataServiceDispatcherImpl::CreatePendingTask(
587     std::shared_ptr<const Job> job, const std::string& worker_address)
588     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
589   int64 task_id = state_.NextAvailableTaskId();
590   Update update;
591   CreatePendingTaskUpdate* create_task = update.mutable_create_pending_task();
592   create_task->set_task_id(task_id);
593   create_task->set_job_id(job->job_id);
594   create_task->set_worker_address(worker_address);
595   create_task->set_starting_round(round_robin_rounds_[job->job_id] + 1);
596   std::shared_ptr<const Worker> worker;
597   TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
598   create_task->set_transfer_address(worker->transfer_address);
599   TF_RETURN_IF_ERROR(Apply(update));
600   return Status::OK();
601 }
602 
CreateTask(std::shared_ptr<const Job> job,const std::string & worker_address,std::shared_ptr<const Task> & task)603 Status DataServiceDispatcherImpl::CreateTask(std::shared_ptr<const Job> job,
604                                              const std::string& worker_address,
605                                              std::shared_ptr<const Task>& task)
606     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
607   int64 task_id = state_.NextAvailableTaskId();
608   Update update;
609   CreateTaskUpdate* create_task = update.mutable_create_task();
610   create_task->set_task_id(task_id);
611   create_task->set_job_id(job->job_id);
612   create_task->set_worker_address(worker_address);
613   std::shared_ptr<const Worker> worker;
614   TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
615   create_task->set_transfer_address(worker->transfer_address);
616   TF_RETURN_IF_ERROR(Apply(update));
617   TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
618   return Status::OK();
619 }
620 
AssignTasks(std::vector<std::shared_ptr<const Task>> tasks)621 Status DataServiceDispatcherImpl::AssignTasks(
622     std::vector<std::shared_ptr<const Task>> tasks) TF_LOCKS_EXCLUDED(mu_) {
623   for (const auto& task : tasks) {
624     TF_RETURN_IF_ERROR(AssignTask(task));
625   }
626   return Status::OK();
627 }
628 
// Returns a (possibly cached) gRPC stub for the worker at `worker_address`.
// Uses a check/create/re-check pattern so the potentially slow channel
// creation runs outside `mu_`; if another thread created a stub
// concurrently, that one wins and the extra stub is discarded.
Status DataServiceDispatcherImpl::GetOrCreateWorkerStub(
    const std::string& worker_address, WorkerService::Stub*& out_stub)
    TF_LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(mu_);
    auto it = worker_stubs_.find(worker_address);
    if (it != worker_stubs_.end()) {
      out_stub = it->second.get();
      return Status::OK();
    }
  }
  // Create the stub without holding the lock.
  std::unique_ptr<WorkerService::Stub> stub;
  TF_RETURN_IF_ERROR(
      CreateWorkerStub(worker_address, config_.protocol(), stub));
  {
    mutex_lock l(mu_);
    // A concurrent call could have already created the stub.
    auto& worker = worker_stubs_[worker_address];
    if (worker == nullptr) {
      worker = std::move(stub);
    }
    out_stub = worker.get();
  }
  return Status::OK();
}
654 
// Sends a ProcessTask RPC telling the task's worker to start running it.
// Dispatcher-state lookups happen under `mu_`; the RPC itself is performed
// without the lock.
Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
    TF_LOCKS_EXCLUDED(mu_) {
  VLOG(2) << "Started assigning task " << task->task_id << " to worker "
          << task->worker_address;
  grpc::ClientContext client_ctx;
  ProcessTaskRequest req;
  TaskDef* task_def = req.mutable_task();
  task_def->set_dataset_id(task->job->dataset_id);
  task_def->set_job_id(task->job->job_id);
  {
    mutex_lock l(mu_);
    std::shared_ptr<const Dataset> dataset;
    TF_RETURN_IF_ERROR(state_.DatasetFromId(task->job->dataset_id, dataset));
    std::string dataset_key =
        DatasetKey(dataset->dataset_id, dataset->fingerprint);
    if (config_.work_dir().empty()) {
      // No shared work dir: inline the dataset definition in the request.
      std::shared_ptr<const DatasetDef> dataset_def;
      TF_RETURN_IF_ERROR(dataset_store_->Get(dataset_key, dataset_def));
      *task_def->mutable_dataset_def() = *dataset_def;
    } else {
      // Shared work dir: send only the path; the worker reads it itself.
      std::string path =
          io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
      task_def->set_path(path);
    }
  }
  task_def->set_task_id(task->task_id);
  task_def->set_processing_mode(ProcessingModeDef(task->job->processing_mode));
  if (task->job->num_consumers.has_value()) {
    task_def->set_num_consumers(task->job->num_consumers.value());
  }
  ProcessTaskResponse resp;
  WorkerService::Stub* stub;
  TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, stub));
  grpc::Status s = stub->ProcessTask(&client_ctx, req, &resp);
  if (!s.ok()) {
    return grpc_util::WrapError(
        absl::StrCat("Failed to submit task to worker ", task->worker_address),
        s);
  }
  VLOG(2) << "Finished assigning task " << task->task_id << " to worker "
          << task->worker_address;
  return Status::OK();
}
698 
// Handles a periodic heartbeat from a tf.data service client. For jobs with
// round-robin consumers this advances the client's round bookkeeping and
// accepts/rejects the job's front pending task; it always returns the job's
// current task list and finished state in `response`.
Status DataServiceDispatcherImpl::ClientHeartbeat(
    const ClientHeartbeatRequest* request, ClientHeartbeatResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  VLOG(4) << "Received heartbeat from client id " << request->job_client_id();
  std::shared_ptr<const Job> job;
  Status s = state_.JobForJobClientId(request->job_client_id(), job);
  // An unknown job client id on a non-fault-tolerant dispatcher most likely
  // means the dispatcher restarted and lost its state; surface a more
  // actionable error in that case.
  if (errors::IsNotFound(s) && !config_.fault_tolerant_mode()) {
    return errors::NotFound(
        "Unknown job client id ", request->job_client_id(),
        ". The dispatcher is not configured to be fault tolerant, so this "
        "could be caused by a dispatcher restart.");
  }
  TF_RETURN_IF_ERROR(s);
  // Track the highest round this client has reported reaching; used below to
  // pick a future target round when the pending task is rejected.
  if (request->optional_current_round_case() ==
      ClientHeartbeatRequest::kCurrentRound) {
    round_robin_rounds_[request->job_client_id()] =
        std::max(round_robin_rounds_[request->job_client_id()],
                 request->current_round());
  }
  if (!job->pending_tasks.empty()) {
    // Only the front pending task is negotiated at a time.
    const auto& task = job->pending_tasks.front();
    Update update;
    ClientHeartbeatUpdate* client_heartbeat = update.mutable_client_heartbeat();
    bool apply_update = false;
    client_heartbeat->set_job_client_id(request->job_client_id());
    absl::optional<int64> blocked_round;
    if (request->optional_blocked_round_case() ==
        ClientHeartbeatRequest::kBlockedRound) {
      blocked_round = request->blocked_round();
    }
    VLOG(1) << "Handling pending task in job client heartbeat. job_client_id: "
            << request->job_client_id()
            << ". current_round: " << request->current_round()
            << ". blocked_round: " << blocked_round.value_or(-1)
            << ". target_round: " << task.target_round;
    if (request->current_round() >= task.target_round) {
      // The client has already reached or passed the proposed round, so the
      // task cannot start there; reject it and propose a later round.
      TaskRejected* rejected = client_heartbeat->mutable_task_rejected();
      // Exponentially try later and later rounds until consumers all agree.
      int64 round_offset = 2;
      for (int i = 0; i < task.failures; ++i) {
        round_offset *= 2;
      }
      rejected->set_new_target_round(
          round_robin_rounds_[request->job_client_id()] + round_offset);
      apply_update = true;
    }
    // The client is blocked at or before the target round and has not yet
    // been recorded as a ready consumer, so record its acceptance.
    if (blocked_round.has_value() &&
        blocked_round.value() <= task.target_round &&
        !task.ready_consumers.contains(request->job_client_id())) {
      client_heartbeat->set_task_accepted(true);
      apply_update = true;
    }
    if (apply_update) {
      TF_RETURN_IF_ERROR(Apply(update));
    }
  }
  // Re-check: the Apply above may have changed the job's pending task list.
  if (!job->pending_tasks.empty()) {
    response->set_block_round(job->pending_tasks.front().target_round);
  }

  std::vector<std::shared_ptr<const Task>> tasks;
  TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, tasks));
  for (const auto& task : tasks) {
    TaskInfo* task_info = response->mutable_task_info()->Add();
    task_info->set_worker_address(task->worker_address);
    task_info->set_transfer_address(task->transfer_address);
    task_info->set_task_id(task->task_id);
    task_info->set_job_id(job->job_id);
    task_info->set_starting_round(task->starting_round);
  }
  response->set_job_finished(job->finished);
  VLOG(4) << "Found " << response->task_info_size()
          << " tasks for job client id " << request->job_client_id();
  return Status::OK();
}
775 
GetWorkers(const GetWorkersRequest * request,GetWorkersResponse * response)776 Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
777                                              GetWorkersResponse* response) {
778   TF_RETURN_IF_ERROR(CheckStarted());
779   mutex_lock l(mu_);
780   VLOG(3) << "Enter GetWorkers";
781   std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
782   for (const auto& worker : workers) {
783     WorkerInfo* info = response->add_workers();
784     info->set_address(worker->address);
785   }
786   VLOG(3) << "Returning list of " << response->workers_size()
787           << " workers from GetWorkers";
788   return Status::OK();
789 }
790 
CheckStarted()791 Status DataServiceDispatcherImpl::CheckStarted() TF_LOCKS_EXCLUDED(mu_) {
792   mutex_lock l(mu_);
793   if (!started_) {
794     return errors::Unavailable("Dispatcher has not started yet.");
795   }
796   return Status::OK();
797 }
798 
RecordSplitProduced(int64 job_id,int64 repetition,bool finished)799 Status DataServiceDispatcherImpl::RecordSplitProduced(int64 job_id,
800                                                       int64 repetition,
801                                                       bool finished)
802     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
803   Update update;
804   ProduceSplitUpdate* produce_split = update.mutable_produce_split();
805   produce_split->set_job_id(job_id);
806   produce_split->set_repetition(repetition);
807   produce_split->set_finished(finished);
808   return Apply(update);
809 }
810 
ApplyWithoutJournaling(const Update & update)811 Status DataServiceDispatcherImpl::ApplyWithoutJournaling(const Update& update)
812     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
813   return state_.Apply(update);
814 }
815 
Apply(const Update & update)816 Status DataServiceDispatcherImpl::Apply(const Update& update)
817     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
818   if (journal_writer_.has_value()) {
819     TF_RETURN_IF_ERROR(journal_writer_.value()->Write(update));
820   }
821   return state_.Apply(update);
822 }
823 
JobGcThread()824 void DataServiceDispatcherImpl::JobGcThread() {
825   int64 next_check_micros = 0;
826   while (true) {
827     mutex_lock l(mu_);
828     while (!cancelled_ && env_->NowMicros() < next_check_micros) {
829       int64 remaining_micros = next_check_micros - env_->NowMicros();
830       job_gc_thread_cv_.wait_for(l,
831                                  std::chrono::microseconds(remaining_micros));
832     }
833     if (cancelled_) {
834       return;
835     }
836     Status s = GcOldJobs();
837     if (!s.ok()) {
838       LOG(WARNING) << "Error garbage collecting old jobs: " << s;
839     }
840     next_check_micros =
841         env_->NowMicros() + (config_.job_gc_check_interval_ms() * 1000);
842   }
843 }
844 
GcOldJobs()845 Status DataServiceDispatcherImpl::GcOldJobs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
846   std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
847   int64 now = env_->NowMicros();
848   for (const auto& job : jobs) {
849     if (job->finished || job->num_clients > 0 ||
850         job->last_client_released_micros < 0 ||
851         now < job->last_client_released_micros +
852                   (config_.job_gc_timeout_ms() * 1000)) {
853       continue;
854     }
855     std::vector<std::shared_ptr<const Task>> tasks;
856     TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, tasks));
857     for (const auto& task : tasks) {
858       if (task->finished) {
859         continue;
860       }
861       Update update;
862       update.mutable_finish_task()->set_task_id(task->task_id);
863       TF_RETURN_IF_ERROR(state_.Apply(update));
864     }
865     DCHECK(job->finished);
866   }
867   return Status::OK();
868 }
869 
GetDatasetDef(int64 dataset_id,std::shared_ptr<const DatasetDef> & dataset_def)870 Status DataServiceDispatcherImpl::GetDatasetDef(
871     int64 dataset_id, std::shared_ptr<const DatasetDef>& dataset_def)
872     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
873   std::shared_ptr<const Dataset> dataset;
874   TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
875   return GetDatasetDef(*dataset, dataset_def);
876 }
877 
GetDatasetDef(const Dataset & dataset,std::shared_ptr<const DatasetDef> & dataset_def)878 Status DataServiceDispatcherImpl::GetDatasetDef(
879     const Dataset& dataset, std::shared_ptr<const DatasetDef>& dataset_def)
880     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
881   std::string key = DatasetKey(dataset.dataset_id, dataset.fingerprint);
882   return dataset_store_->Get(key, dataset_def);
883 }
884 
885 }  // namespace data
886 }  // namespace tensorflow
887