1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/dispatcher_state.h"
16
17 #include <memory>
18
19 #include "tensorflow/core/data/service/journal.h"
20 #include "tensorflow/core/data/service/journal.pb.h"
21 #include "tensorflow/core/platform/errors.h"
22
23 namespace tensorflow {
24 namespace data {
25
DispatcherState()26 DispatcherState::DispatcherState() {}
27
Apply(const Update & update)28 Status DispatcherState::Apply(const Update& update) {
29 switch (update.update_type_case()) {
30 case Update::kRegisterDataset:
31 RegisterDataset(update.register_dataset());
32 break;
33 case Update::kRegisterWorker:
34 RegisterWorker(update.register_worker());
35 break;
36 case Update::kCreateJob:
37 CreateJob(update.create_job());
38 break;
39 case Update::kProduceSplit:
40 ProduceSplit(update.produce_split());
41 break;
42 case Update::kAcquireJobClient:
43 AcquireJobClient(update.acquire_job_client());
44 break;
45 case Update::kReleaseJobClient:
46 ReleaseJobClient(update.release_job_client());
47 break;
48 case Update::kCreatePendingTask:
49 CreatePendingTask(update.create_pending_task());
50 break;
51 case Update::kClientHeartbeat:
52 ClientHeartbeat(update.client_heartbeat());
53 break;
54 case Update::kCreateTask:
55 CreateTask(update.create_task());
56 break;
57 case Update::kFinishTask:
58 FinishTask(update.finish_task());
59 break;
60 case Update::UPDATE_TYPE_NOT_SET:
61 return errors::Internal("Update type not set.");
62 }
63
64 return Status::OK();
65 }
66
RegisterDataset(const RegisterDatasetUpdate & register_dataset)67 void DispatcherState::RegisterDataset(
68 const RegisterDatasetUpdate& register_dataset) {
69 int64 id = register_dataset.dataset_id();
70 int64 fingerprint = register_dataset.fingerprint();
71 auto dataset = std::make_shared<Dataset>(id, fingerprint);
72 DCHECK(!datasets_by_id_.contains(id));
73 datasets_by_id_[id] = dataset;
74 DCHECK(!datasets_by_fingerprint_.contains(fingerprint));
75 datasets_by_fingerprint_[fingerprint] = dataset;
76 next_available_dataset_id_ = std::max(next_available_dataset_id_, id + 1);
77 }
78
RegisterWorker(const RegisterWorkerUpdate & register_worker)79 void DispatcherState::RegisterWorker(
80 const RegisterWorkerUpdate& register_worker) {
81 std::string address = register_worker.worker_address();
82 DCHECK(!workers_.contains(address));
83 workers_[address] =
84 std::make_shared<Worker>(address, register_worker.transfer_address());
85 tasks_by_worker_[address] =
86 absl::flat_hash_map<int64, std::shared_ptr<Task>>();
87 }
88
CreateJob(const CreateJobUpdate & create_job)89 void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
90 int64 job_id = create_job.job_id();
91 absl::optional<NamedJobKey> named_job_key;
92 if (create_job.has_named_job_key()) {
93 named_job_key.emplace(create_job.named_job_key().name(),
94 create_job.named_job_key().index());
95 }
96 absl::optional<int64> num_consumers;
97 if (create_job.optional_num_consumers_case() ==
98 CreateJobUpdate::kNumConsumers) {
99 num_consumers = create_job.num_consumers();
100 }
101 auto job = std::make_shared<Job>(job_id, create_job.dataset_id(),
102 ProcessingMode(create_job.processing_mode()),
103 named_job_key, num_consumers);
104 DCHECK(!jobs_.contains(job_id));
105 jobs_[job_id] = job;
106 tasks_by_job_[job_id] = std::vector<std::shared_ptr<Task>>();
107 if (named_job_key.has_value()) {
108 DCHECK(!named_jobs_.contains(named_job_key.value()));
109 named_jobs_[named_job_key.value()] = job;
110 }
111 next_available_job_id_ = std::max(next_available_job_id_, job_id + 1);
112 }
113
ProduceSplit(const ProduceSplitUpdate & produce_split)114 void DispatcherState::ProduceSplit(const ProduceSplitUpdate& produce_split) {
115 std::shared_ptr<Job> job = jobs_[produce_split.job_id()];
116 DCHECK(job->distributed_epoch_state.has_value());
117 DistributedEpochState& state = job->distributed_epoch_state.value();
118 DCHECK_EQ(produce_split.repetition(), state.repetition);
119 if (produce_split.finished()) {
120 state.repetition++;
121 state.split_provider_index = 0;
122 return;
123 }
124 state.split_provider_index++;
125 }
126
AcquireJobClient(const AcquireJobClientUpdate & acquire_job_client)127 void DispatcherState::AcquireJobClient(
128 const AcquireJobClientUpdate& acquire_job_client) {
129 int64 job_client_id = acquire_job_client.job_client_id();
130 std::shared_ptr<Job>& job = jobs_for_client_ids_[job_client_id];
131 DCHECK(!job);
132 job = jobs_[acquire_job_client.job_id()];
133 DCHECK(job);
134 job->num_clients++;
135 next_available_job_client_id_ =
136 std::max(next_available_job_client_id_, job_client_id + 1);
137 }
138
ReleaseJobClient(const ReleaseJobClientUpdate & release_job_client)139 void DispatcherState::ReleaseJobClient(
140 const ReleaseJobClientUpdate& release_job_client) {
141 int64 job_client_id = release_job_client.job_client_id();
142 std::shared_ptr<Job>& job = jobs_for_client_ids_[job_client_id];
143 DCHECK(job);
144 job->num_clients--;
145 DCHECK_GE(job->num_clients, 0);
146 job->last_client_released_micros = release_job_client.time_micros();
147 jobs_for_client_ids_.erase(job_client_id);
148 }
149
CreatePendingTask(const CreatePendingTaskUpdate & create_pending_task)150 void DispatcherState::CreatePendingTask(
151 const CreatePendingTaskUpdate& create_pending_task) {
152 int64 task_id = create_pending_task.task_id();
153 auto& task = tasks_[task_id];
154 DCHECK_EQ(task, nullptr);
155 auto& job = jobs_[create_pending_task.job_id()];
156 DCHECK_NE(job, nullptr);
157 task =
158 std::make_shared<Task>(task_id, job, create_pending_task.worker_address(),
159 create_pending_task.transfer_address());
160 job->pending_tasks.emplace(task, create_pending_task.starting_round());
161 tasks_by_worker_[create_pending_task.worker_address()][task->task_id] = task;
162 next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
163 }
164
ClientHeartbeat(const ClientHeartbeatUpdate & client_heartbeat)165 void DispatcherState::ClientHeartbeat(
166 const ClientHeartbeatUpdate& client_heartbeat) {
167 int64 job_client_id = client_heartbeat.job_client_id();
168 auto& job = jobs_for_client_ids_[job_client_id];
169 DCHECK(!job->pending_tasks.empty());
170 auto& task = job->pending_tasks.front();
171 if (client_heartbeat.has_task_rejected()) {
172 task.failures++;
173 task.ready_consumers.clear();
174 task.target_round = client_heartbeat.task_rejected().new_target_round();
175 }
176 if (client_heartbeat.task_accepted()) {
177 task.ready_consumers.insert(job_client_id);
178 if (task.ready_consumers.size() == job->num_consumers.value()) {
179 task.task->starting_round = task.target_round;
180 tasks_by_job_[job->job_id].push_back(task.task);
181 job->pending_tasks.pop();
182 }
183 }
184 }
185
CreateTask(const CreateTaskUpdate & create_task)186 void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
187 int64 task_id = create_task.task_id();
188 auto& task = tasks_[task_id];
189 DCHECK_EQ(task, nullptr);
190 auto& job = jobs_[create_task.job_id()];
191 DCHECK_NE(job, nullptr);
192 task = std::make_shared<Task>(task_id, job, create_task.worker_address(),
193 create_task.transfer_address());
194 tasks_by_job_[create_task.job_id()].push_back(task);
195 tasks_by_worker_[create_task.worker_address()][task->task_id] = task;
196 next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
197 }
198
FinishTask(const FinishTaskUpdate & finish_task)199 void DispatcherState::FinishTask(const FinishTaskUpdate& finish_task) {
200 VLOG(2) << "Marking task " << finish_task.task_id() << " as finished";
201 int64 task_id = finish_task.task_id();
202 auto& task = tasks_[task_id];
203 DCHECK(task != nullptr);
204 task->finished = true;
205 tasks_by_worker_[task->worker_address].erase(task->task_id);
206 bool all_finished = true;
207 for (const auto& task_for_job : tasks_by_job_[task->job->job_id]) {
208 if (!task_for_job->finished) {
209 all_finished = false;
210 }
211 }
212 VLOG(3) << "Job " << task->job->job_id << " finished: " << all_finished;
213 jobs_[task->job->job_id]->finished = all_finished;
214 }
215
NextAvailableDatasetId() const216 int64 DispatcherState::NextAvailableDatasetId() const {
217 return next_available_dataset_id_;
218 }
219
DatasetFromId(int64 id,std::shared_ptr<const Dataset> & dataset) const220 Status DispatcherState::DatasetFromId(
221 int64 id, std::shared_ptr<const Dataset>& dataset) const {
222 auto it = datasets_by_id_.find(id);
223 if (it == datasets_by_id_.end()) {
224 return errors::NotFound("Dataset id ", id, " not found");
225 }
226 dataset = it->second;
227 return Status::OK();
228 }
229
DatasetFromFingerprint(uint64 fingerprint,std::shared_ptr<const Dataset> & dataset) const230 Status DispatcherState::DatasetFromFingerprint(
231 uint64 fingerprint, std::shared_ptr<const Dataset>& dataset) const {
232 auto it = datasets_by_fingerprint_.find(fingerprint);
233 if (it == datasets_by_fingerprint_.end()) {
234 return errors::NotFound("Dataset fingerprint ", fingerprint, " not found");
235 }
236 dataset = it->second;
237 return Status::OK();
238 }
239
WorkerFromAddress(const std::string & address,std::shared_ptr<const Worker> & worker) const240 Status DispatcherState::WorkerFromAddress(
241 const std::string& address, std::shared_ptr<const Worker>& worker) const {
242 auto it = workers_.find(address);
243 if (it == workers_.end()) {
244 return errors::NotFound("Worker with address ", address, " not found.");
245 }
246 worker = it->second;
247 return Status::OK();
248 }
249
250 std::vector<std::shared_ptr<const DispatcherState::Worker>>
ListWorkers() const251 DispatcherState::ListWorkers() const {
252 std::vector<std::shared_ptr<const Worker>> workers;
253 workers.reserve(workers_.size());
254 for (const auto& it : workers_) {
255 workers.push_back(it.second);
256 }
257 return workers;
258 }
259
260 std::vector<std::shared_ptr<const DispatcherState::Job>>
ListJobs()261 DispatcherState::ListJobs() {
262 std::vector<std::shared_ptr<const DispatcherState::Job>> jobs;
263 jobs.reserve(jobs_.size());
264 for (const auto& it : jobs_) {
265 jobs.push_back(it.second);
266 }
267 return jobs;
268 }
269
JobFromId(int64 id,std::shared_ptr<const Job> & job) const270 Status DispatcherState::JobFromId(int64 id,
271 std::shared_ptr<const Job>& job) const {
272 auto it = jobs_.find(id);
273 if (it == jobs_.end()) {
274 return errors::NotFound("Job id ", id, " not found");
275 }
276 job = it->second;
277 return Status::OK();
278 }
279
NamedJobByKey(NamedJobKey named_job_key,std::shared_ptr<const Job> & job) const280 Status DispatcherState::NamedJobByKey(NamedJobKey named_job_key,
281 std::shared_ptr<const Job>& job) const {
282 auto it = named_jobs_.find(named_job_key);
283 if (it == named_jobs_.end()) {
284 return errors::NotFound("Named job key (", named_job_key.name, ", ",
285 named_job_key.index, ") not found");
286 }
287 job = it->second;
288 return Status::OK();
289 }
290
NextAvailableJobId() const291 int64 DispatcherState::NextAvailableJobId() const {
292 return next_available_job_id_;
293 }
294
JobForJobClientId(int64 job_client_id,std::shared_ptr<const Job> & job)295 Status DispatcherState::JobForJobClientId(int64 job_client_id,
296 std::shared_ptr<const Job>& job) {
297 job = jobs_for_client_ids_[job_client_id];
298 if (!job) {
299 return errors::NotFound("Job client id not found: ", job_client_id);
300 }
301 return Status::OK();
302 }
303
NextAvailableJobClientId() const304 int64 DispatcherState::NextAvailableJobClientId() const {
305 return next_available_job_client_id_;
306 }
307
TaskFromId(int64 id,std::shared_ptr<const Task> & task) const308 Status DispatcherState::TaskFromId(int64 id,
309 std::shared_ptr<const Task>& task) const {
310 auto it = tasks_.find(id);
311 if (it == tasks_.end()) {
312 return errors::NotFound("Task ", id, " not found");
313 }
314 task = it->second;
315 return Status::OK();
316 }
317
TasksForJob(int64 job_id,std::vector<std::shared_ptr<const Task>> & tasks) const318 Status DispatcherState::TasksForJob(
319 int64 job_id, std::vector<std::shared_ptr<const Task>>& tasks) const {
320 auto it = tasks_by_job_.find(job_id);
321 if (it == tasks_by_job_.end()) {
322 return errors::NotFound("Job ", job_id, " not found");
323 }
324 tasks.clear();
325 tasks.reserve(it->second.size());
326 for (const auto& task : it->second) {
327 tasks.push_back(task);
328 }
329 return Status::OK();
330 }
331
TasksForWorker(absl::string_view worker_address,std::vector<std::shared_ptr<const Task>> & tasks) const332 Status DispatcherState::TasksForWorker(
333 absl::string_view worker_address,
334 std::vector<std::shared_ptr<const Task>>& tasks) const {
335 auto it = tasks_by_worker_.find(worker_address);
336 if (it == tasks_by_worker_.end()) {
337 return errors::NotFound("Worker ", worker_address, " not found");
338 }
339 const absl::flat_hash_map<int64, std::shared_ptr<Task>>& worker_tasks =
340 it->second;
341 tasks.reserve(worker_tasks.size());
342 for (const auto& task : worker_tasks) {
343 tasks.push_back(task.second);
344 }
345 return Status::OK();
346 }
347
NextAvailableTaskId() const348 int64 DispatcherState::NextAvailableTaskId() const {
349 return next_available_task_id_;
350 }
351
352 } // namespace data
353 } // namespace tensorflow
354