/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/service/dispatcher_impl.h"

#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <tuple>
#include <utility>

#ifdef PLATFORM_GOOGLE
#include "file/logging/log_lines.h"
#endif
#include "grpcpp/create_channel.h"
#include "grpcpp/impl/codegen/server_context.h"
#include "grpcpp/security/credentials.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dataset_store.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/journal.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/hash_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {

namespace {
// The name of the journal directory inside the dispatcher's working directory.
// This name is load-bearing; do not change.
constexpr char kJournalDir[] = "tf_data_dispatcher_journal";
// The name of the datasets directory inside the dispatcher's working
// directory.
constexpr char kDatasetsDir[] = "datasets";

constexpr std::array<const char*, 8> kNodeNameSharingOps = {
    "HashTable",
    "HashTableV2",
    "MutableHashTable",
    "MutableHashTableV2",
    "MutableDenseHashTable",
    "MutableDenseHashTableV2",
    "MutableHashTableOfTensors",
    "MutableHashTableOfTensorsV2",
};

using Dataset = DispatcherState::Dataset;
using Worker = DispatcherState::Worker;
using NamedJobKey = DispatcherState::NamedJobKey;
using Job = DispatcherState::Job;
using Task = DispatcherState::Task;

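// Returns the path of the journal directory under the given work directory.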
std::string JournalDir(const std::string& work_dir) {
  return io::JoinPath(work_dir, kJournalDir);
}

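// Returns the path of the datasets directory under the given work directory.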
std::string DatasetsDir(const std::string& work_dir) {
  return io::JoinPath(work_dir, kDatasetsDir);
}

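// Returns the key under which a dataset with the given id and fingerprint is
// stored in the dataset store.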
std::string DatasetKey(int64 id, uint64 fingerprint) {
  return absl::StrCat("id_", id, "_fp_", fingerprint);
}

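// Creates a gRPC stub for talking to the worker at `address`, with no limit
// on the receive message size since dataset elements can be large.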
Status CreateWorkerStub(const std::string& address, const std::string& protocol,
                        std::unique_ptr<WorkerService::Stub>& stub) {
  ::grpc::ChannelArguments args;
  args.SetMaxReceiveMessageSize(-1);
  std::shared_ptr<::grpc::ChannelCredentials> credentials;
  TF_RETURN_IF_ERROR(
      CredentialsFactory::CreateClientCredentials(protocol, &credentials));
  auto channel = ::grpc::CreateCustomChannel(address, credentials, args);
  stub = WorkerService::NewStub(channel);
  return Status::OK();
}

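// Prepares a dataset graph for the tf.data service: enables node name sharing
// for hash table ops (see the comment below) and clears device placements so
// that the graph can run on any worker.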
void PrepareGraph(GraphDef* graph) {
  for (NodeDef& node : *graph->mutable_node()) {
    for (const auto& op : kNodeNameSharingOps) {
      // Set `use_node_name_sharing` to `true` so that resources aren't deleted
      // prematurely. Otherwise, resources may be deleted when their ops are
      // deleted at the end of the GraphRunner::Run used by standalone::Dataset.
      if (node.op() == op) {
        (*node.mutable_attr())["use_node_name_sharing"].set_b(true);
      }
    }
    if (!node.device().empty()) {
      *node.mutable_device() = "";
    }
  }
  StripDevicePlacement(graph->mutable_library());
}
}  // namespace

DataServiceDispatcherImpl::DataServiceDispatcherImpl(
    const experimental::DispatcherConfig& config)
    : config_(config), env_(Env::Default()) {
  if (config_.work_dir().empty()) {
    dataset_store_ = absl::make_unique<MemoryDatasetStore>();
  } else {
    dataset_store_ = absl::make_unique<FileSystemDatasetStore>(
        DatasetsDir(config_.work_dir()));
  }
}

DataServiceDispatcherImpl::~DataServiceDispatcherImpl() {
  {
    mutex_lock l(mu_);
    cancelled_ = true;
    job_gc_thread_cv_.notify_all();
  }
  job_gc_thread_.reset();
}

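// Starts the dispatcher: launches the job GC thread, restores state from the
// journal when fault_tolerant_mode is enabled, and marks the service as
// started. A minimal usage sketch (the work_dir value is hypothetical, and the
// config fields shown are only the ones this file reads):
//
//   experimental::DispatcherConfig config;
//   config.set_work_dir("/tmp/dispatcher_work_dir");
//   config.set_fault_tolerant_mode(true);
//   DataServiceDispatcherImpl dispatcher(config);
//   TF_RETURN_IF_ERROR(dispatcher.Start());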
Status DataServiceDispatcherImpl::Start() {
  mutex_lock l(mu_);
  job_gc_thread_ = absl::WrapUnique(
      env_->StartThread({}, "job-gc-thread", [&] { JobGcThread(); }));
  if (config_.work_dir().empty()) {
    if (config_.fault_tolerant_mode()) {
      return errors::InvalidArgument(
          "fault_tolerant_mode is True, but no work_dir is configured.");
    }
  } else {
    TF_RETURN_IF_ERROR(
        env_->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
  }
  if (!config_.fault_tolerant_mode()) {
    LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will "
                 "not be able to recover its state on restart.";
    started_ = true;
    return Status::OK();
  }
  journal_writer_ = absl::make_unique<FileJournalWriter>(
      env_, JournalDir(config_.work_dir()));
  LOG(INFO) << "Attempting to restore dispatcher state from journal in "
            << JournalDir(config_.work_dir());
  Update update;
  bool end_of_journal = false;
  FileJournalReader reader(env_, JournalDir(config_.work_dir()));
  Status s = reader.Read(update, end_of_journal);
  if (errors::IsNotFound(s)) {
    LOG(INFO) << "No journal found. Starting dispatcher from new state.";
  } else if (!s.ok()) {
    return s;
  } else {
    while (!end_of_journal) {
      TF_RETURN_IF_ERROR(ApplyWithoutJournaling(update));
      TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
    }
  }
  for (const auto& job : state_.ListJobs()) {
    if (job->processing_mode == ProcessingMode::DISTRIBUTED_EPOCH) {
      TF_RETURN_IF_ERROR(
          RestoreSplitProvider(*job, split_providers_[job->job_id]));
    }
  }
  // Initialize the journal writer in `Start` so that we fail fast in case it
  // can't be initialized.
  TF_RETURN_IF_ERROR(journal_writer_.value()->EnsureInitialized());
  started_ = true;
  return Status::OK();
}

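// Restores a job's split provider after a restart by building a new provider
// for the job's dataset and replaying it up to the split index recorded in
// the job's distributed_epoch state.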
Status DataServiceDispatcherImpl::RestoreSplitProvider(
    const Job& job, std::unique_ptr<SplitProvider>& restored)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  int64 index = job.distributed_epoch_state.value().split_provider_index;
  VLOG(1) << "Restoring split provider for job " << job.job_id << " to index "
          << index;
  std::unique_ptr<SplitProvider> split_provider;
  TF_RETURN_IF_ERROR(MakeSplitProvider(job.dataset_id, split_provider));
  Tensor unused_tensor;
  bool unused_end_of_splits;
  for (int i = 0; i < index; ++i) {
    TF_RETURN_IF_ERROR(
        split_provider->GetNext(&unused_tensor, &unused_end_of_splits));
  }
  restored = std::move(split_provider);
  return Status::OK();
}

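// Handles a worker heartbeat. Unknown workers are registered (and tasks are
// created for them); the response then tells the worker which tasks it should
// start and which of its current tasks it should delete.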
Status DataServiceDispatcherImpl::WorkerHeartbeat(
    const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  VLOG(4) << "Received worker heartbeat request from worker "
          << request->worker_address();
  mutex_lock l(mu_);
  const std::string& worker_address = request->worker_address();
  std::vector<std::shared_ptr<const Task>> correct_tasks;
  Status s = state_.TasksForWorker(worker_address, correct_tasks);
  if (!s.ok()) {
    if (!errors::IsNotFound(s)) {
      return s;
    }
    VLOG(1) << "Registering new worker at address " << worker_address;
    Update update;
    update.mutable_register_worker()->set_worker_address(worker_address);
    update.mutable_register_worker()->set_transfer_address(
        request->transfer_address());
    TF_RETURN_IF_ERROR(Apply(update));
    TF_RETURN_IF_ERROR(CreateTasksForWorker(worker_address));
    TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, correct_tasks));
  }

  absl::flat_hash_set<int64> current_tasks;
  current_tasks.insert(request->current_tasks().cbegin(),
                       request->current_tasks().cend());
  absl::flat_hash_set<int64> correct_tasks_set;

  for (const auto& task : correct_tasks) {
    correct_tasks_set.insert(task->task_id);
    if (current_tasks.contains(task->task_id)) {
      continue;
    }
    TaskDef* task_def = response->add_new_tasks();
    std::shared_ptr<const Dataset> dataset;
    TF_RETURN_IF_ERROR(state_.DatasetFromId(task->job->dataset_id, dataset));
    std::string dataset_key =
        DatasetKey(dataset->dataset_id, dataset->fingerprint);
    if (config_.work_dir().empty()) {
      std::shared_ptr<const DatasetDef> dataset_def;
      TF_RETURN_IF_ERROR(dataset_store_->Get(dataset_key, dataset_def));
      *task_def->mutable_dataset_def() = *dataset_def;
    } else {
      std::string path =
          io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
      task_def->set_path(path);
    }
    task_def->set_dataset_id(task->job->dataset_id);
    task_def->set_job_id(task->job->job_id);
    task_def->set_task_id(task->task_id);
    task_def->set_processing_mode(
        ProcessingModeDef(task->job->processing_mode));
    if (task->job->num_consumers.has_value()) {
      task_def->set_num_consumers(task->job->num_consumers.value());
    }
  }
  for (int64 current_task : current_tasks) {
    if (!correct_tasks_set.contains(current_task)) {
      response->add_tasks_to_delete(current_task);
    }
  }

  VLOG(4) << "Finished worker heartbeat for worker at address "
          << request->worker_address();
  return Status::OK();
}

Status DataServiceDispatcherImpl::WorkerUpdate(
    const WorkerUpdateRequest* request, WorkerUpdateResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  for (auto& update : request->updates()) {
    int64 task_id = update.task_id();
    std::shared_ptr<const Task> task;
    TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
    if (update.completed()) {
      if (task->finished) {
        VLOG(1) << "Received completion update for already-finished task "
                << task->task_id << " on worker " << task->worker_address;
        continue;
      }
      Update finish_update;
      finish_update.mutable_finish_task()->set_task_id(task_id);
      TF_RETURN_IF_ERROR(Apply(finish_update));
      VLOG(3) << "Task " << task_id << " from job " << task->job->job_id
              << " completed";
    }
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetDatasetDef(
    const GetDatasetDefRequest* request, GetDatasetDefResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  std::shared_ptr<const Dataset> dataset;
  TF_RETURN_IF_ERROR(state_.DatasetFromId(request->dataset_id(), dataset));
  std::shared_ptr<const DatasetDef> dataset_def;
  TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def));
  *response->mutable_dataset_def() = *dataset_def;
  return Status::OK();
}

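// Returns the next split for a distributed_epoch job. If the requested
// repetition has already ended, end_of_splits is reported instead; when the
// current repetition runs out of splits, a fresh split provider is created
// for the next repetition.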
Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request,
                                           GetSplitResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  int64 job_id = request->job_id();
  int64 repetition = request->repetition();
  VLOG(3) << "Received GetSplit request for job " << job_id << ", repetition "
          << repetition;
  std::shared_ptr<const Job> job;
  TF_RETURN_IF_ERROR(state_.JobFromId(job_id, job));
  if (!job->distributed_epoch_state.has_value()) {
    return errors::FailedPrecondition(
        "Cannot get split for job ", job_id,
        ", since it is not a distributed_epoch job.");
  }
  int64 current_repetition = job->distributed_epoch_state.value().repetition;
  if (repetition < current_repetition) {
    response->set_end_of_splits(true);
    VLOG(3) << "Returning end_of_splits since current repetition "
            << current_repetition
            << " is greater than the requested repetition " << repetition;
    return Status::OK();
  }
  SplitProvider* split_provider = split_providers_[job_id].get();
  DCHECK(split_provider != nullptr);
  Tensor split;
  bool end_of_splits = false;
  TF_RETURN_IF_ERROR(split_provider->GetNext(&split, &end_of_splits));
  TF_RETURN_IF_ERROR(RecordSplitProduced(job_id, repetition, end_of_splits));
  response->set_end_of_splits(end_of_splits);
  if (end_of_splits) {
    // Create a new split provider for the next repetition.
    TF_RETURN_IF_ERROR(
        MakeSplitProvider(job->dataset_id, split_providers_[job_id]));
  } else {
    split.AsProtoTensorContent(response->mutable_split());
  }
  VLOG(3) << "Returning from GetSplit, end_of_splits=" << end_of_splits;
  return Status::OK();
}

Status DataServiceDispatcherImpl::MakeSplitProvider(
    int64 dataset_id, std::unique_ptr<SplitProvider>& split_provider)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::shared_ptr<const Dataset> dataset;
  TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
  std::shared_ptr<const DatasetDef> dataset_def;
  TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def));
  standalone::Dataset::Params params;
  std::unique_ptr<standalone::Dataset> standalone_dataset;
  TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
      params, dataset_def->graph(), &standalone_dataset));
  TF_RETURN_IF_ERROR(standalone_dataset->MakeSplitProvider(&split_provider));
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetVersion(const GetVersionRequest* request,
                                             GetVersionResponse* response) {
  response->set_version(kDataServiceVersion);
  return Status::OK();
}

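// Registers a dataset, deduplicating by graph fingerprint: if a dataset with
// an identical fingerprint has already been registered, the existing dataset
// id is returned instead of registering a copy.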
Status DataServiceDispatcherImpl::GetOrRegisterDataset(
    const GetOrRegisterDatasetRequest* request,
    GetOrRegisterDatasetResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  uint64 fingerprint;
  DatasetDef dataset_def = request->dataset();
  GraphDef* graph = dataset_def.mutable_graph();
  PrepareGraph(graph);
  TF_RETURN_IF_ERROR(HashGraph(*graph, &fingerprint));

  mutex_lock l(mu_);
#if defined(PLATFORM_GOOGLE)
  VLOG_LINES(4,
             absl::StrCat("Registering dataset graph: ", graph->DebugString()));
#else
  VLOG(4) << "Registering dataset graph: " << graph->DebugString();
#endif
  std::shared_ptr<const Dataset> dataset;
  Status s = state_.DatasetFromFingerprint(fingerprint, dataset);
  if (s.ok()) {
    int64 id = dataset->dataset_id;
    VLOG(3) << "Received duplicate RegisterDataset request with fingerprint "
            << fingerprint << ". Returning id " << id;
    response->set_dataset_id(id);
    return Status::OK();
  } else if (!errors::IsNotFound(s)) {
    return s;
  }

  int64 id;
  TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, dataset_def, id));
  response->set_dataset_id(id);
  VLOG(3) << "Registered new dataset with id " << id;
  return Status::OK();
}

Status DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint,
                                                  const DatasetDef& dataset,
                                                  int64& dataset_id)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  dataset_id = state_.NextAvailableDatasetId();
  Update update;
  RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset();
  register_dataset->set_dataset_id(dataset_id);
  register_dataset->set_fingerprint(fingerprint);
  TF_RETURN_IF_ERROR(
      dataset_store_->Put(DatasetKey(dataset_id, fingerprint), dataset));
  return Apply(update);
}

Status DataServiceDispatcherImpl::GetOrCreateJob(
    const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  VLOG(3) << "GetOrCreateJob(" << request->DebugString() << ")";
  absl::optional<NamedJobKey> key;
  if (request->has_job_key()) {
    key.emplace(request->job_key().job_name(),
                request->job_key().job_name_index());
  }
  ProcessingMode requested_processing_mode =
      ProcessingMode(request->processing_mode());
  std::shared_ptr<const Job> job;
  std::vector<std::shared_ptr<const Task>> tasks;
  {
    mutex_lock l(mu_);
    if (key.has_value()) {
      Status s = state_.NamedJobByKey(key.value(), job);
      if (s.ok()) {
        TF_RETURN_IF_ERROR(ValidateMatchingJob(job, requested_processing_mode,
                                               request->dataset_id()));
        int64 job_client_id;
        TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
        response->set_job_client_id(job_client_id);
        VLOG(3) << "Found existing job for name=" << key.value().name
                << ", index=" << key.value().index
                << ". job_id: " << job->job_id;
        return Status::OK();
      } else if (!errors::IsNotFound(s)) {
        return s;
      }
    }
    absl::optional<int64> num_consumers;
    if (request->optional_num_consumers_case() ==
        GetOrCreateJobRequest::kNumConsumers) {
      num_consumers = request->num_consumers();
    }
    TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(),
                                 requested_processing_mode, key, num_consumers,
                                 job));
    int64 job_client_id;
    TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
    response->set_job_client_id(job_client_id);
    TF_RETURN_IF_ERROR(CreateTasksForJob(job, tasks));
  }
  TF_RETURN_IF_ERROR(AssignTasks(tasks));
  VLOG(3) << "Created job " << job->job_id << " for CreateJob("
          << request->DebugString() << ")";
  return Status::OK();
}

Status DataServiceDispatcherImpl::ReleaseJobClient(
    const ReleaseJobClientRequest* request,
    ReleaseJobClientResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  int64 job_client_id = request->job_client_id();
  std::shared_ptr<const Job> job;
  TF_RETURN_IF_ERROR(state_.JobForJobClientId(job_client_id, job));
  Update update;
  ReleaseJobClientUpdate* release_job_client =
      update.mutable_release_job_client();
  release_job_client->set_job_client_id(job_client_id);
  release_job_client->set_time_micros(env_->NowMicros());
  TF_RETURN_IF_ERROR(Apply(update));
  return Status::OK();
}

// Validates that the job matches the given processing_mode and dataset_id.
Status DataServiceDispatcherImpl::ValidateMatchingJob(
    std::shared_ptr<const Job> job, ProcessingMode processing_mode,
    int64 dataset_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  DCHECK(job->named_job_key.has_value());
  std::string job_name = job->named_job_key->name;
  if (job->processing_mode != processing_mode) {
    std::string requested = ProcessingModeToString(processing_mode);
    std::string actual = ProcessingModeToString(job->processing_mode);
    return errors::FailedPrecondition(
        "Tried to create a job with name ", job_name, " and processing_mode <",
        requested,
        "> but there is already an existing job with that name using "
        "processing mode <",
        actual, ">");
  }
  if (job->dataset_id != dataset_id) {
    return errors::FailedPrecondition(
        "Tried to create a job with name ", job_name, " and dataset_id ",
        dataset_id,
        ", but there is already an existing job with that name using "
        "dataset_id ",
        job->dataset_id);
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::CreateJob(
    int64 dataset_id, ProcessingMode processing_mode,
    absl::optional<NamedJobKey> named_job_key,
    absl::optional<int64> num_consumers, std::shared_ptr<const Job>& job)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  switch (processing_mode) {
    case ProcessingMode::PARALLEL_EPOCHS:
    case ProcessingMode::DISTRIBUTED_EPOCH:
      break;
    default:
      return errors::Internal(
          absl::StrCat("ProcessingMode ", processing_mode, " not recognized"));
  }
  int64 job_id = state_.NextAvailableJobId();
  if (processing_mode == ProcessingMode::DISTRIBUTED_EPOCH) {
    TF_RETURN_IF_ERROR(
        MakeSplitProvider(dataset_id, split_providers_[job_id]));
  }
  Update update;
  CreateJobUpdate* create_job = update.mutable_create_job();
  create_job->set_job_id(job_id);
  create_job->set_dataset_id(dataset_id);
  create_job->set_processing_mode(ProcessingModeDef(processing_mode));
  if (named_job_key.has_value()) {
    NamedJobKeyDef* key = create_job->mutable_named_job_key();
    key->set_name(named_job_key->name);
    key->set_index(named_job_key->index);
  }
  if (num_consumers.has_value()) {
    create_job->set_num_consumers(num_consumers.value());
  }
  TF_RETURN_IF_ERROR(Apply(update));
  TF_RETURN_IF_ERROR(state_.JobFromId(job_id, job));
  return Status::OK();
}

Status DataServiceDispatcherImpl::CreateTasksForWorker(
    const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
  for (const auto& job : jobs) {
    if (job->finished) {
      continue;
    }
    if (job->num_consumers.has_value()) {
      TF_RETURN_IF_ERROR(CreatePendingTask(job, worker_address));
      continue;
    }
    std::shared_ptr<const Task> task;
    TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::AcquireJobClientId(
    const std::shared_ptr<const Job>& job, int64& job_client_id)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  job_client_id = state_.NextAvailableJobClientId();
  Update update;
  AcquireJobClientUpdate* acquire_job_client =
      update.mutable_acquire_job_client();
  acquire_job_client->set_job_client_id(job_client_id);
  acquire_job_client->set_job_id(job->job_id);
  TF_RETURN_IF_ERROR(Apply(update));
  return Status::OK();
}

Status DataServiceDispatcherImpl::CreateTasksForJob(
    std::shared_ptr<const Job> job,
    std::vector<std::shared_ptr<const Task>>& tasks)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
  tasks.clear();
  tasks.reserve(workers.size());
  for (const auto& worker : workers) {
    std::shared_ptr<const Task> task;
    TF_RETURN_IF_ERROR(CreateTask(job, worker->address, task));
    tasks.push_back(task);
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::CreatePendingTask(
    std::shared_ptr<const Job> job, const std::string& worker_address)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  int64 task_id = state_.NextAvailableTaskId();
  Update update;
  CreatePendingTaskUpdate* create_task = update.mutable_create_pending_task();
  create_task->set_task_id(task_id);
  create_task->set_job_id(job->job_id);
  create_task->set_worker_address(worker_address);
  create_task->set_starting_round(round_robin_rounds_[job->job_id] + 1);
  std::shared_ptr<const Worker> worker;
  TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
  create_task->set_transfer_address(worker->transfer_address);
  TF_RETURN_IF_ERROR(Apply(update));
  return Status::OK();
}

Status DataServiceDispatcherImpl::CreateTask(std::shared_ptr<const Job> job,
                                             const std::string& worker_address,
                                             std::shared_ptr<const Task>& task)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  int64 task_id = state_.NextAvailableTaskId();
  Update update;
  CreateTaskUpdate* create_task = update.mutable_create_task();
  create_task->set_task_id(task_id);
  create_task->set_job_id(job->job_id);
  create_task->set_worker_address(worker_address);
  std::shared_ptr<const Worker> worker;
  TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
  create_task->set_transfer_address(worker->transfer_address);
  TF_RETURN_IF_ERROR(Apply(update));
  TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
  return Status::OK();
}

Status DataServiceDispatcherImpl::AssignTasks(
    std::vector<std::shared_ptr<const Task>> tasks) TF_LOCKS_EXCLUDED(mu_) {
  for (const auto& task : tasks) {
    TF_RETURN_IF_ERROR(AssignTask(task));
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetOrCreateWorkerStub(
    const std::string& worker_address, WorkerService::Stub*& out_stub)
    TF_LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(mu_);
    auto it = worker_stubs_.find(worker_address);
    if (it != worker_stubs_.end()) {
      out_stub = it->second.get();
      return Status::OK();
    }
  }
  std::unique_ptr<WorkerService::Stub> stub;
  TF_RETURN_IF_ERROR(
      CreateWorkerStub(worker_address, config_.protocol(), stub));
  {
    mutex_lock l(mu_);
    // A concurrent call could have already created the stub.
    auto& worker = worker_stubs_[worker_address];
    if (worker == nullptr) {
      worker = std::move(stub);
    }
    out_stub = worker.get();
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
    TF_LOCKS_EXCLUDED(mu_) {
  VLOG(2) << "Started assigning task " << task->task_id << " to worker "
          << task->worker_address;
  grpc::ClientContext client_ctx;
  ProcessTaskRequest req;
  TaskDef* task_def = req.mutable_task();
  task_def->set_dataset_id(task->job->dataset_id);
  task_def->set_job_id(task->job->job_id);
  {
    mutex_lock l(mu_);
    std::shared_ptr<const Dataset> dataset;
    TF_RETURN_IF_ERROR(state_.DatasetFromId(task->job->dataset_id, dataset));
    std::string dataset_key =
        DatasetKey(dataset->dataset_id, dataset->fingerprint);
    if (config_.work_dir().empty()) {
      std::shared_ptr<const DatasetDef> dataset_def;
      TF_RETURN_IF_ERROR(dataset_store_->Get(dataset_key, dataset_def));
      *task_def->mutable_dataset_def() = *dataset_def;
    } else {
      std::string path =
          io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
      task_def->set_path(path);
    }
  }
  task_def->set_task_id(task->task_id);
  task_def->set_processing_mode(ProcessingModeDef(task->job->processing_mode));
  if (task->job->num_consumers.has_value()) {
    task_def->set_num_consumers(task->job->num_consumers.value());
  }
  ProcessTaskResponse resp;
  WorkerService::Stub* stub;
  TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, stub));
  grpc::Status s = stub->ProcessTask(&client_ctx, req, &resp);
  if (!s.ok()) {
    return grpc_util::WrapError(
        absl::StrCat("Failed to submit task to worker ", task->worker_address),
        s);
  }
  VLOG(2) << "Finished assigning task " << task->task_id << " to worker "
          << task->worker_address;
  return Status::OK();
}

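// Handles a client heartbeat: tracks the client's round-robin progress,
// resolves any pending (round-robin) task for the job, and returns the job's
// current task list.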
Status DataServiceDispatcherImpl::ClientHeartbeat(
    const ClientHeartbeatRequest* request, ClientHeartbeatResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  VLOG(4) << "Received heartbeat from client id " << request->job_client_id();
  std::shared_ptr<const Job> job;
  Status s = state_.JobForJobClientId(request->job_client_id(), job);
  if (errors::IsNotFound(s) && !config_.fault_tolerant_mode()) {
    return errors::NotFound(
        "Unknown job client id ", request->job_client_id(),
        ". The dispatcher is not configured to be fault tolerant, so this "
        "could be caused by a dispatcher restart.");
  }
  TF_RETURN_IF_ERROR(s);
  if (request->optional_current_round_case() ==
      ClientHeartbeatRequest::kCurrentRound) {
    round_robin_rounds_[request->job_client_id()] =
        std::max(round_robin_rounds_[request->job_client_id()],
                 request->current_round());
  }
  if (!job->pending_tasks.empty()) {
    const auto& task = job->pending_tasks.front();
    Update update;
    ClientHeartbeatUpdate* client_heartbeat =
        update.mutable_client_heartbeat();
    bool apply_update = false;
    client_heartbeat->set_job_client_id(request->job_client_id());
    absl::optional<int64> blocked_round;
    if (request->optional_blocked_round_case() ==
        ClientHeartbeatRequest::kBlockedRound) {
      blocked_round = request->blocked_round();
    }
    VLOG(1) << "Handling pending task in job client heartbeat. job_client_id: "
            << request->job_client_id()
            << ". current_round: " << request->current_round()
            << ". blocked_round: " << blocked_round.value_or(-1)
            << ". target_round: " << task.target_round;
    if (request->current_round() >= task.target_round) {
      TaskRejected* rejected = client_heartbeat->mutable_task_rejected();
      // Exponentially try later and later rounds until consumers all agree.
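      // For example: with 0 prior failures the new target is 2 rounds past
      // the latest known round; with 2 prior failures it is 8 rounds past
      // (round_offset = 2^(failures + 1)).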
      int64 round_offset = 2;
      for (int i = 0; i < task.failures; ++i) {
        round_offset *= 2;
      }
      rejected->set_new_target_round(
          round_robin_rounds_[request->job_client_id()] + round_offset);
      apply_update = true;
    }
    if (blocked_round.has_value() &&
        blocked_round.value() <= task.target_round &&
        !task.ready_consumers.contains(request->job_client_id())) {
      client_heartbeat->set_task_accepted(true);
      apply_update = true;
    }
    if (apply_update) {
      TF_RETURN_IF_ERROR(Apply(update));
    }
  }
  if (!job->pending_tasks.empty()) {
    response->set_block_round(job->pending_tasks.front().target_round);
  }

  std::vector<std::shared_ptr<const Task>> tasks;
  TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, tasks));
  for (const auto& task : tasks) {
    TaskInfo* task_info = response->mutable_task_info()->Add();
    task_info->set_worker_address(task->worker_address);
    task_info->set_transfer_address(task->transfer_address);
    task_info->set_task_id(task->task_id);
    task_info->set_job_id(job->job_id);
    task_info->set_starting_round(task->starting_round);
  }
  response->set_job_finished(job->finished);
  VLOG(4) << "Found " << response->task_info_size()
          << " tasks for job client id " << request->job_client_id();
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
                                             GetWorkersResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  VLOG(3) << "Enter GetWorkers";
  std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
  for (const auto& worker : workers) {
    WorkerInfo* info = response->add_workers();
    info->set_address(worker->address);
  }
  VLOG(3) << "Returning list of " << response->workers_size()
          << " workers from GetWorkers";
  return Status::OK();
}

Status DataServiceDispatcherImpl::CheckStarted() TF_LOCKS_EXCLUDED(mu_) {
  mutex_lock l(mu_);
  if (!started_) {
    return errors::Unavailable("Dispatcher has not started yet.");
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::RecordSplitProduced(int64 job_id,
                                                      int64 repetition,
                                                      bool finished)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  Update update;
  ProduceSplitUpdate* produce_split = update.mutable_produce_split();
  produce_split->set_job_id(job_id);
  produce_split->set_repetition(repetition);
  produce_split->set_finished(finished);
  return Apply(update);
}

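// Applies a state update without journaling it. Used when replaying updates
// that are already present in the journal.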
Status DataServiceDispatcherImpl::ApplyWithoutJournaling(const Update& update)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  return state_.Apply(update);
}

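// Journals a state update (when journaling is enabled) before applying it, so
// that no applied update is ever missing from the journal.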
Status DataServiceDispatcherImpl::Apply(const Update& update)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (journal_writer_.has_value()) {
    TF_RETURN_IF_ERROR(journal_writer_.value()->Write(update));
  }
  return state_.Apply(update);
}

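// Body of the background thread which periodically garbage-collects old jobs,
// waking early if the dispatcher is cancelled.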
void DataServiceDispatcherImpl::JobGcThread() {
  int64 next_check_micros = 0;
  while (true) {
    mutex_lock l(mu_);
    while (!cancelled_ && env_->NowMicros() < next_check_micros) {
      int64 remaining_micros = next_check_micros - env_->NowMicros();
      job_gc_thread_cv_.wait_for(l,
                                 std::chrono::microseconds(remaining_micros));
    }
    if (cancelled_) {
      return;
    }
    Status s = GcOldJobs();
    if (!s.ok()) {
      LOG(WARNING) << "Error garbage collecting old jobs: " << s;
    }
    next_check_micros =
        env_->NowMicros() + (config_.job_gc_check_interval_ms() * 1000);
  }
}

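// Finishes the outstanding tasks of any job which has no remaining clients
// and whose last client was released more than job_gc_timeout_ms ago.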
Status DataServiceDispatcherImpl::GcOldJobs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
  int64 now = env_->NowMicros();
  for (const auto& job : jobs) {
    if (job->finished || job->num_clients > 0 ||
        job->last_client_released_micros < 0 ||
        now < job->last_client_released_micros +
                  (config_.job_gc_timeout_ms() * 1000)) {
      continue;
    }
    std::vector<std::shared_ptr<const Task>> tasks;
    TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, tasks));
    for (const auto& task : tasks) {
      if (task->finished) {
        continue;
      }
      Update update;
      update.mutable_finish_task()->set_task_id(task->task_id);
      TF_RETURN_IF_ERROR(state_.Apply(update));
    }
    DCHECK(job->finished);
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetDatasetDef(
    int64 dataset_id, std::shared_ptr<const DatasetDef>& dataset_def)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::shared_ptr<const Dataset> dataset;
  TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
  return GetDatasetDef(*dataset, dataset_def);
}

Status DataServiceDispatcherImpl::GetDatasetDef(
    const Dataset& dataset, std::shared_ptr<const DatasetDef>& dataset_def)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::string key = DatasetKey(dataset.dataset_id, dataset.fingerprint);
  return dataset_store_->Get(key, dataset_def);
}

}  // namespace data
}  // namespace tensorflow