1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/dispatcher_state.h"
16 
17 #include <memory>
18 
19 #include "tensorflow/core/data/service/journal.h"
20 #include "tensorflow/core/data/service/journal.pb.h"
21 #include "tensorflow/core/platform/errors.h"
22 
23 namespace tensorflow {
24 namespace data {
25 
DispatcherState()26 DispatcherState::DispatcherState() {}
27 
Apply(const Update & update)28 Status DispatcherState::Apply(const Update& update) {
29   switch (update.update_type_case()) {
30     case Update::kRegisterDataset:
31       RegisterDataset(update.register_dataset());
32       break;
33     case Update::kRegisterWorker:
34       RegisterWorker(update.register_worker());
35       break;
36     case Update::kCreateJob:
37       CreateJob(update.create_job());
38       break;
39     case Update::kProduceSplit:
40       ProduceSplit(update.produce_split());
41       break;
42     case Update::kAcquireJobClient:
43       AcquireJobClient(update.acquire_job_client());
44       break;
45     case Update::kReleaseJobClient:
46       ReleaseJobClient(update.release_job_client());
47       break;
48     case Update::kCreatePendingTask:
49       CreatePendingTask(update.create_pending_task());
50       break;
51     case Update::kClientHeartbeat:
52       ClientHeartbeat(update.client_heartbeat());
53       break;
54     case Update::kCreateTask:
55       CreateTask(update.create_task());
56       break;
57     case Update::kFinishTask:
58       FinishTask(update.finish_task());
59       break;
60     case Update::UPDATE_TYPE_NOT_SET:
61       return errors::Internal("Update type not set.");
62   }
63 
64   return Status::OK();
65 }
66 
RegisterDataset(const RegisterDatasetUpdate & register_dataset)67 void DispatcherState::RegisterDataset(
68     const RegisterDatasetUpdate& register_dataset) {
69   int64 id = register_dataset.dataset_id();
70   int64 fingerprint = register_dataset.fingerprint();
71   auto dataset = std::make_shared<Dataset>(id, fingerprint);
72   DCHECK(!datasets_by_id_.contains(id));
73   datasets_by_id_[id] = dataset;
74   DCHECK(!datasets_by_fingerprint_.contains(fingerprint));
75   datasets_by_fingerprint_[fingerprint] = dataset;
76   next_available_dataset_id_ = std::max(next_available_dataset_id_, id + 1);
77 }
78 
RegisterWorker(const RegisterWorkerUpdate & register_worker)79 void DispatcherState::RegisterWorker(
80     const RegisterWorkerUpdate& register_worker) {
81   std::string address = register_worker.worker_address();
82   DCHECK(!workers_.contains(address));
83   workers_[address] =
84       std::make_shared<Worker>(address, register_worker.transfer_address());
85   tasks_by_worker_[address] =
86       absl::flat_hash_map<int64, std::shared_ptr<Task>>();
87 }
88 
CreateJob(const CreateJobUpdate & create_job)89 void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
90   int64 job_id = create_job.job_id();
91   absl::optional<NamedJobKey> named_job_key;
92   if (create_job.has_named_job_key()) {
93     named_job_key.emplace(create_job.named_job_key().name(),
94                           create_job.named_job_key().index());
95   }
96   absl::optional<int64> num_consumers;
97   if (create_job.optional_num_consumers_case() ==
98       CreateJobUpdate::kNumConsumers) {
99     num_consumers = create_job.num_consumers();
100   }
101   auto job = std::make_shared<Job>(job_id, create_job.dataset_id(),
102                                    ProcessingMode(create_job.processing_mode()),
103                                    named_job_key, num_consumers);
104   DCHECK(!jobs_.contains(job_id));
105   jobs_[job_id] = job;
106   tasks_by_job_[job_id] = std::vector<std::shared_ptr<Task>>();
107   if (named_job_key.has_value()) {
108     DCHECK(!named_jobs_.contains(named_job_key.value()));
109     named_jobs_[named_job_key.value()] = job;
110   }
111   next_available_job_id_ = std::max(next_available_job_id_, job_id + 1);
112 }
113 
ProduceSplit(const ProduceSplitUpdate & produce_split)114 void DispatcherState::ProduceSplit(const ProduceSplitUpdate& produce_split) {
115   std::shared_ptr<Job> job = jobs_[produce_split.job_id()];
116   DCHECK(job->distributed_epoch_state.has_value());
117   DistributedEpochState& state = job->distributed_epoch_state.value();
118   DCHECK_EQ(produce_split.repetition(), state.repetition);
119   if (produce_split.finished()) {
120     state.repetition++;
121     state.split_provider_index = 0;
122     return;
123   }
124   state.split_provider_index++;
125 }
126 
AcquireJobClient(const AcquireJobClientUpdate & acquire_job_client)127 void DispatcherState::AcquireJobClient(
128     const AcquireJobClientUpdate& acquire_job_client) {
129   int64 job_client_id = acquire_job_client.job_client_id();
130   std::shared_ptr<Job>& job = jobs_for_client_ids_[job_client_id];
131   DCHECK(!job);
132   job = jobs_[acquire_job_client.job_id()];
133   DCHECK(job);
134   job->num_clients++;
135   next_available_job_client_id_ =
136       std::max(next_available_job_client_id_, job_client_id + 1);
137 }
138 
ReleaseJobClient(const ReleaseJobClientUpdate & release_job_client)139 void DispatcherState::ReleaseJobClient(
140     const ReleaseJobClientUpdate& release_job_client) {
141   int64 job_client_id = release_job_client.job_client_id();
142   std::shared_ptr<Job>& job = jobs_for_client_ids_[job_client_id];
143   DCHECK(job);
144   job->num_clients--;
145   DCHECK_GE(job->num_clients, 0);
146   job->last_client_released_micros = release_job_client.time_micros();
147   jobs_for_client_ids_.erase(job_client_id);
148 }
149 
CreatePendingTask(const CreatePendingTaskUpdate & create_pending_task)150 void DispatcherState::CreatePendingTask(
151     const CreatePendingTaskUpdate& create_pending_task) {
152   int64 task_id = create_pending_task.task_id();
153   auto& task = tasks_[task_id];
154   DCHECK_EQ(task, nullptr);
155   auto& job = jobs_[create_pending_task.job_id()];
156   DCHECK_NE(job, nullptr);
157   task =
158       std::make_shared<Task>(task_id, job, create_pending_task.worker_address(),
159                              create_pending_task.transfer_address());
160   job->pending_tasks.emplace(task, create_pending_task.starting_round());
161   tasks_by_worker_[create_pending_task.worker_address()][task->task_id] = task;
162   next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
163 }
164 
ClientHeartbeat(const ClientHeartbeatUpdate & client_heartbeat)165 void DispatcherState::ClientHeartbeat(
166     const ClientHeartbeatUpdate& client_heartbeat) {
167   int64 job_client_id = client_heartbeat.job_client_id();
168   auto& job = jobs_for_client_ids_[job_client_id];
169   DCHECK(!job->pending_tasks.empty());
170   auto& task = job->pending_tasks.front();
171   if (client_heartbeat.has_task_rejected()) {
172     task.failures++;
173     task.ready_consumers.clear();
174     task.target_round = client_heartbeat.task_rejected().new_target_round();
175   }
176   if (client_heartbeat.task_accepted()) {
177     task.ready_consumers.insert(job_client_id);
178     if (task.ready_consumers.size() == job->num_consumers.value()) {
179       task.task->starting_round = task.target_round;
180       tasks_by_job_[job->job_id].push_back(task.task);
181       job->pending_tasks.pop();
182     }
183   }
184 }
185 
CreateTask(const CreateTaskUpdate & create_task)186 void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
187   int64 task_id = create_task.task_id();
188   auto& task = tasks_[task_id];
189   DCHECK_EQ(task, nullptr);
190   auto& job = jobs_[create_task.job_id()];
191   DCHECK_NE(job, nullptr);
192   task = std::make_shared<Task>(task_id, job, create_task.worker_address(),
193                                 create_task.transfer_address());
194   tasks_by_job_[create_task.job_id()].push_back(task);
195   tasks_by_worker_[create_task.worker_address()][task->task_id] = task;
196   next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
197 }
198 
FinishTask(const FinishTaskUpdate & finish_task)199 void DispatcherState::FinishTask(const FinishTaskUpdate& finish_task) {
200   VLOG(2) << "Marking task " << finish_task.task_id() << " as finished";
201   int64 task_id = finish_task.task_id();
202   auto& task = tasks_[task_id];
203   DCHECK(task != nullptr);
204   task->finished = true;
205   tasks_by_worker_[task->worker_address].erase(task->task_id);
206   bool all_finished = true;
207   for (const auto& task_for_job : tasks_by_job_[task->job->job_id]) {
208     if (!task_for_job->finished) {
209       all_finished = false;
210     }
211   }
212   VLOG(3) << "Job " << task->job->job_id << " finished: " << all_finished;
213   jobs_[task->job->job_id]->finished = all_finished;
214 }
215 
NextAvailableDatasetId() const216 int64 DispatcherState::NextAvailableDatasetId() const {
217   return next_available_dataset_id_;
218 }
219 
DatasetFromId(int64 id,std::shared_ptr<const Dataset> & dataset) const220 Status DispatcherState::DatasetFromId(
221     int64 id, std::shared_ptr<const Dataset>& dataset) const {
222   auto it = datasets_by_id_.find(id);
223   if (it == datasets_by_id_.end()) {
224     return errors::NotFound("Dataset id ", id, " not found");
225   }
226   dataset = it->second;
227   return Status::OK();
228 }
229 
DatasetFromFingerprint(uint64 fingerprint,std::shared_ptr<const Dataset> & dataset) const230 Status DispatcherState::DatasetFromFingerprint(
231     uint64 fingerprint, std::shared_ptr<const Dataset>& dataset) const {
232   auto it = datasets_by_fingerprint_.find(fingerprint);
233   if (it == datasets_by_fingerprint_.end()) {
234     return errors::NotFound("Dataset fingerprint ", fingerprint, " not found");
235   }
236   dataset = it->second;
237   return Status::OK();
238 }
239 
WorkerFromAddress(const std::string & address,std::shared_ptr<const Worker> & worker) const240 Status DispatcherState::WorkerFromAddress(
241     const std::string& address, std::shared_ptr<const Worker>& worker) const {
242   auto it = workers_.find(address);
243   if (it == workers_.end()) {
244     return errors::NotFound("Worker with address ", address, " not found.");
245   }
246   worker = it->second;
247   return Status::OK();
248 }
249 
250 std::vector<std::shared_ptr<const DispatcherState::Worker>>
ListWorkers() const251 DispatcherState::ListWorkers() const {
252   std::vector<std::shared_ptr<const Worker>> workers;
253   workers.reserve(workers_.size());
254   for (const auto& it : workers_) {
255     workers.push_back(it.second);
256   }
257   return workers;
258 }
259 
260 std::vector<std::shared_ptr<const DispatcherState::Job>>
ListJobs()261 DispatcherState::ListJobs() {
262   std::vector<std::shared_ptr<const DispatcherState::Job>> jobs;
263   jobs.reserve(jobs_.size());
264   for (const auto& it : jobs_) {
265     jobs.push_back(it.second);
266   }
267   return jobs;
268 }
269 
JobFromId(int64 id,std::shared_ptr<const Job> & job) const270 Status DispatcherState::JobFromId(int64 id,
271                                   std::shared_ptr<const Job>& job) const {
272   auto it = jobs_.find(id);
273   if (it == jobs_.end()) {
274     return errors::NotFound("Job id ", id, " not found");
275   }
276   job = it->second;
277   return Status::OK();
278 }
279 
NamedJobByKey(NamedJobKey named_job_key,std::shared_ptr<const Job> & job) const280 Status DispatcherState::NamedJobByKey(NamedJobKey named_job_key,
281                                       std::shared_ptr<const Job>& job) const {
282   auto it = named_jobs_.find(named_job_key);
283   if (it == named_jobs_.end()) {
284     return errors::NotFound("Named job key (", named_job_key.name, ", ",
285                             named_job_key.index, ") not found");
286   }
287   job = it->second;
288   return Status::OK();
289 }
290 
NextAvailableJobId() const291 int64 DispatcherState::NextAvailableJobId() const {
292   return next_available_job_id_;
293 }
294 
JobForJobClientId(int64 job_client_id,std::shared_ptr<const Job> & job)295 Status DispatcherState::JobForJobClientId(int64 job_client_id,
296                                           std::shared_ptr<const Job>& job) {
297   job = jobs_for_client_ids_[job_client_id];
298   if (!job) {
299     return errors::NotFound("Job client id not found: ", job_client_id);
300   }
301   return Status::OK();
302 }
303 
NextAvailableJobClientId() const304 int64 DispatcherState::NextAvailableJobClientId() const {
305   return next_available_job_client_id_;
306 }
307 
TaskFromId(int64 id,std::shared_ptr<const Task> & task) const308 Status DispatcherState::TaskFromId(int64 id,
309                                    std::shared_ptr<const Task>& task) const {
310   auto it = tasks_.find(id);
311   if (it == tasks_.end()) {
312     return errors::NotFound("Task ", id, " not found");
313   }
314   task = it->second;
315   return Status::OK();
316 }
317 
TasksForJob(int64 job_id,std::vector<std::shared_ptr<const Task>> & tasks) const318 Status DispatcherState::TasksForJob(
319     int64 job_id, std::vector<std::shared_ptr<const Task>>& tasks) const {
320   auto it = tasks_by_job_.find(job_id);
321   if (it == tasks_by_job_.end()) {
322     return errors::NotFound("Job ", job_id, " not found");
323   }
324   tasks.clear();
325   tasks.reserve(it->second.size());
326   for (const auto& task : it->second) {
327     tasks.push_back(task);
328   }
329   return Status::OK();
330 }
331 
TasksForWorker(absl::string_view worker_address,std::vector<std::shared_ptr<const Task>> & tasks) const332 Status DispatcherState::TasksForWorker(
333     absl::string_view worker_address,
334     std::vector<std::shared_ptr<const Task>>& tasks) const {
335   auto it = tasks_by_worker_.find(worker_address);
336   if (it == tasks_by_worker_.end()) {
337     return errors::NotFound("Worker ", worker_address, " not found");
338   }
339   const absl::flat_hash_map<int64, std::shared_ptr<Task>>& worker_tasks =
340       it->second;
341   tasks.reserve(worker_tasks.size());
342   for (const auto& task : worker_tasks) {
343     tasks.push_back(task.second);
344   }
345   return Status::OK();
346 }
347 
NextAvailableTaskId() const348 int64 DispatcherState::NextAvailableTaskId() const {
349   return next_available_task_id_;
350 }
351 
352 }  // namespace data
353 }  // namespace tensorflow
354