/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_

#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dataset_store.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/dispatcher_state.h"
#include "tensorflow/core/data/service/journal.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace data {

// A service which coordinates a pool of workers to serve dataset elements over
// RPC.
//
// Glossary:
// * Dataset: A definition of how to generate a potentially large collection of
//   elements.
// * Job: A coordinated phase of reading from the tf.data service. A job
//   produces some amount of data, and (potentially multiple) consumers consume
//   the data from the job until there is no data left. Each job has a
//   ProcessingModeDef which determines what data it produces.
// * Task: A job is broken into multiple tasks, each of which represents
//   iterating over all or part of the dataset. Workers process tasks.
// * Consumer: A process reading from the tf.data service.
//
// **Adding workers**
//
// tf.data service supports adding workers mid-job. When a new worker connects
// to the dispatcher, the dispatcher creates a new task for the worker, one
// task for each outstanding job. Consumers periodically heartbeat to the
// dispatcher to learn about new tasks.
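//
// A minimal sketch of the per-worker task creation described above, mirroring
// `CreateTasksForWorker` (illustrative only; `unfinished_jobs` is a
// hypothetical name for the jobs tracked in `DispatcherState`):
//
//   for (const auto& job : unfinished_jobs) {
//     std::shared_ptr<const DispatcherState::Task> task;
//     TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
//   }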
//
// For non-round-robin reads, there is no coordination among consumers. Each
// consumer will start reading from the new task as soon as it learns about the
// task from its heartbeat. Round robin reads, on the other hand, require
// consumers to read from the same task at each step. This requires
// coordination to ensure that all consumers start reading from the new task in
// the same round.
//
// The protocol for adding round robin tasks works as follows (a sketch of the
// target-round bookkeeping follows this list):
//
// - The dispatcher keeps track of which round each round-robin job is on. This
//   information is reported by consumers in their heartbeats.
// - When a new worker joins and there is an outstanding round-robin job, we
//   create a new task for the job and assign it to the worker.
//   However, we don't yet report the task in consumer heartbeats.
//   We call the task a "pending task" and add it to its job's "pending tasks"
//   queue.
// - When we create a pending task, we choose a "target round" to try adding
//   the task to. The target round is chosen by adding a "target round delta"
//   to the latest reported round for the job.
// - When a consumer heartbeats for a job and there is a pending task for that
//   job, the dispatcher sends a heartbeat response telling the consumer to
//   block before reading from the target round.
// - When a consumer receives a heartbeat response telling it to block
//   (before reading) a round, the consumer tries to block the round. If the
//   consumer has already started the round, it is too late to block the
//   round.
// - When consumers heartbeat, they tell the dispatcher their current round and
//   whether they have blocked themselves from reading past a certain round. If
//   a consumer reports a current round exceeding the target round, the target
//   round has failed and needs to be increased. We choose a new target round
//   by doubling the previous target round delta. If the consumer reports that
//   it has blocked before the target round, we record that the consumer is
//   ready to add the new task. Once all consumers are ready to add the new
//   task, we remove the task from the pending tasks list and begin reporting
//   the task to consumers. We set the "starting_round" field of the task to
//   indicate the target round where all consumers should start reading from
//   the task.
// - If a new worker joins while there are already pending tasks, a pending
//   task for the new worker is created and queued behind the existing tasks.
//   The new task won't be considered until all previous pending tasks have
//   been successfully added.
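//
// A minimal sketch of the target-round bookkeeping above. All names here
// (`latest_round`, `delta`, `consumer_round`, `ready_consumers`) are
// hypothetical; the real state is tracked per job by the dispatcher:
//
//   int64 target_round = latest_round + delta;
//   // On each consumer heartbeat for the job:
//   if (consumer_round >= target_round) {
//     // Too late to block the target round; double the delta and retry.
//     delta *= 2;
//     target_round = consumer_round + delta;
//   } else if (consumer_blocked_before_target) {
//     ready_consumers.insert(consumer_id);
//   }
//   if (ready_consumers.size() == num_consumers) {
//     // Promote the pending task to a regular task; consumers begin reading
//     // it at `target_round`, reported via the task's "starting_round" field.
//   }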
//
// An example of executing this protocol with two consumers could go as follows:
// 1. Consumers read up to round 50 and heartbeat that they are on round 50.
// 2. A new worker joins. Dispatcher chooses round 51 as the target round.
// 3. Consumer 1 heartbeats that its current round is 50. Dispatcher tells it to
//    block round 51.
// 4. Consumer 2 heartbeats that its current round is 51. Dispatcher realizes
//    that it is too late to block round 51 and chooses round 53 as the new
//    target round. Dispatcher tells consumer 2 to block round 53.
// 5. Consumer 1 heartbeats that its current round is 50 and that it has blocked
//    round 51. Dispatcher tells it to block round 53 instead. Dispatcher
//    records that consumer 1 is ready to add a task in round 53.
// 6. Consumer 2 heartbeats that its current round is 52 and it has blocked
//    round 53. Dispatcher realizes that all consumers are blocked on round 53
//    or earlier and promotes the task from pending to regular. Dispatcher sends
//    consumer 2 a task list containing the new task, and tells consumer 2 that
//    it no longer needs to block.
// 7. Consumer 1 heartbeats. Dispatcher sends consumer 1 the task list
//    containing the new task, and tells it that it no longer needs to block.
//
class DataServiceDispatcherImpl {
 public:
  explicit DataServiceDispatcherImpl(
      const experimental::DispatcherConfig& config);

  ~DataServiceDispatcherImpl();

  // Starts the dispatcher. If there is a journal, this will read from the
  // journal to restore the dispatcher's state.
  Status Start();
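  //
  // A minimal sketch of the recovery step (illustrative only; `reader` is a
  // hypothetical `JournalReader` positioned at the start of the journal):
  //
  //   Update update;
  //   bool end_of_journal = false;
  //   TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
  //   while (!end_of_journal) {
  //     TF_RETURN_IF_ERROR(ApplyWithoutJournaling(update));
  //     TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
  //   }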

  // See dispatcher.proto for API documentation.

  /// Worker-facing API.
  Status WorkerHeartbeat(const WorkerHeartbeatRequest* request,
                         WorkerHeartbeatResponse* response);
  Status WorkerUpdate(const WorkerUpdateRequest* request,
                      WorkerUpdateResponse* response);
  Status GetDatasetDef(const GetDatasetDefRequest* request,
                       GetDatasetDefResponse* response);
  Status GetSplit(const GetSplitRequest* request, GetSplitResponse* response);

  /// Client-facing API.
  Status GetVersion(const GetVersionRequest* request,
                    GetVersionResponse* response);
  Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request,
                              GetOrRegisterDatasetResponse* response);
  Status GetOrCreateJob(const GetOrCreateJobRequest* request,
                        GetOrCreateJobResponse* response);
  Status ReleaseJobClient(const ReleaseJobClientRequest* request,
                          ReleaseJobClientResponse* response);
  Status ClientHeartbeat(const ClientHeartbeatRequest* request,
                         ClientHeartbeatResponse* response);
  Status GetWorkers(const GetWorkersRequest* request,
                    GetWorkersResponse* response);
 private:
  // Restores a `SplitProvider` from the state in `job` and stores it in
  // `restored`.
  Status RestoreSplitProvider(const DispatcherState::Job& job,
                              std::unique_ptr<SplitProvider>& restored)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Makes a split provider for the specified `dataset_id`, and stores it in
  // `split_provider`.
  Status MakeSplitProvider(int64 dataset_id,
                           std::unique_ptr<SplitProvider>& split_provider)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Registers a dataset with the given fingerprint, storing the new dataset's
  // id in `dataset_id`.
  Status RegisterDataset(uint64 fingerprint, const DatasetDef& dataset,
                         int64& dataset_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Gets a worker's stub from `worker_stubs_`, or if none exists, creates a
  // stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is
  // stored in `out_stub`.
  Status GetOrCreateWorkerStub(const std::string& worker_address,
                               WorkerService::Stub*& out_stub)
      TF_LOCKS_EXCLUDED(mu_);
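  //
  // A minimal sketch of the get-or-create pattern above (illustrative only;
  // `CreateWorkerStub` stands in for whatever builds the gRPC stub):
  //
  //   mutex_lock l(mu_);
  //   auto& stub = worker_stubs_[worker_address];
  //   if (stub == nullptr) {
  //     TF_RETURN_IF_ERROR(CreateWorkerStub(worker_address, stub));
  //   }
  //   out_stub = stub.get();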
  // Creates a job and stores it in `job`. This method updates the
  // dispatcher state with the new job, but does not assign tasks to workers.
  Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
                   absl::optional<DispatcherState::NamedJobKey> named_job_key,
                   absl::optional<int64> num_consumers,
                   std::shared_ptr<const DispatcherState::Job>& job)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates tasks for the specified worker, one task for every unfinished job.
  Status CreateTasksForWorker(const std::string& worker_address);
  // Acquires a job client id to read from the given job and sets
  // `job_client_id`.
  Status AcquireJobClientId(
      const std::shared_ptr<const DispatcherState::Job>& job,
      int64& job_client_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates one task for each worker, for the given job. The created tasks are
  // stored in `tasks`. This method only updates dispatcher metadata with the
  // new tasks, but doesn't assign the tasks to the workers.
  Status CreateTasksForJob(
      std::shared_ptr<const DispatcherState::Job> job,
      std::vector<std::shared_ptr<const DispatcherState::Task>>& tasks)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Creates a pending task for a round robin job. All consumers need to agree
  // on which round to add the task in before the pending task can be promoted
  // to a regular task.
  Status CreatePendingTask(std::shared_ptr<const DispatcherState::Job> job,
                           const std::string& worker_address)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates a new task for a job, storing the created task in `task`.
  Status CreateTask(std::shared_ptr<const DispatcherState::Job> job,
                    const std::string& worker_address,
                    std::shared_ptr<const DispatcherState::Task>& task);
  // Assigns the list of tasks to the workers indicated by their
  // `worker_address` fields.
  Status AssignTasks(
      std::vector<std::shared_ptr<const DispatcherState::Task>> tasks)
      TF_LOCKS_EXCLUDED(mu_);
  // Assigns a task to the worker indicated by its `worker_address` field.
  Status AssignTask(std::shared_ptr<const DispatcherState::Task> task)
      TF_LOCKS_EXCLUDED(mu_);
  // Validates that an existing job matches the given processing_mode and
  // dataset_id, returning an error status describing any difference.
  Status ValidateMatchingJob(std::shared_ptr<const DispatcherState::Job> job,
                             ProcessingMode processing_mode, int64 dataset_id)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Checks that the dispatcher has started, returning UNAVAILABLE if it
  // hasn't.
  Status CheckStarted() TF_LOCKS_EXCLUDED(mu_);
  // Records that a split was produced by a call to `GetSplit`.
  Status RecordSplitProduced(int64 job_id, int64 repetition, bool finished)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Applies a state update, updating both the journal and the in-memory state.
  Status Apply(const Update& update) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Applies a state update, but doesn't update the journal. Only meant to be
  // used while restoring state when the dispatcher starts.
  Status ApplyWithoutJournaling(const Update& update)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
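  //
  // A minimal sketch of the ordering between the two (illustrative only): the
  // journal write must succeed before the in-memory update, so that a restart
  // replays exactly the updates that were applied:
  //
  //   if (journal_writer_.has_value()) {
  //     TF_RETURN_IF_ERROR((*journal_writer_)->Write(update));
  //   }
  //   return state_.Apply(update);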
  // A thread which periodically checks for jobs to clean up.
  void JobGcThread();
  // Scans for old jobs and marks them as finished.
  Status GcOldJobs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
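  //
  // A minimal sketch of the gc loop (illustrative only; `interval_ms` is a
  // hypothetical wakeup interval derived from the dispatcher config):
  //
  //   while (true) {
  //     mutex_lock l(mu_);
  //     job_gc_thread_cv_.wait_for(l, std::chrono::milliseconds(interval_ms));
  //     if (cancelled_) return;
  //     Status s = GcOldJobs();
  //     if (!s.ok()) LOG(WARNING) << "Failed to gc old jobs: " << s;
  //   }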
  // Gets a `DatasetDef` from `dataset_store_` for the given dataset id, and
  // stores it in `dataset_def`.
  Status GetDatasetDef(int64 dataset_id,
                       std::shared_ptr<const DatasetDef>& dataset_def)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Gets a `DatasetDef` from `dataset_store_` for the given dataset, and
  // stores it in `dataset_def`.
  Status GetDatasetDef(const DispatcherState::Dataset& dataset,
                       std::shared_ptr<const DatasetDef>& dataset_def)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  const experimental::DispatcherConfig& config_;
  Env* env_;

  mutex mu_;
  bool started_ TF_GUARDED_BY(mu_) = false;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;

  // Cached worker stubs for communicating with workers.
  absl::flat_hash_map<std::string, std::unique_ptr<WorkerService::Stub>>
      worker_stubs_ TF_GUARDED_BY(mu_);
  // Store of dataset definitions.
  std::unique_ptr<DatasetStore> dataset_store_ TF_GUARDED_BY(mu_);
  // Mapping from job id to `SplitProvider`s for jobs with processing mode
  // DISTRIBUTED_EPOCH.
  absl::flat_hash_map<int64, std::unique_ptr<SplitProvider>> split_providers_
      TF_GUARDED_BY(mu_);
  // Mapping from round robin job id to the round the job is currently on. This
  // is based on the data provided by client heartbeats, and may be stale.
  absl::flat_hash_map<int64, int64> round_robin_rounds_ TF_GUARDED_BY(mu_);

  absl::optional<std::unique_ptr<JournalWriter>> journal_writer_
      TF_GUARDED_BY(mu_);
  DispatcherState state_ TF_GUARDED_BY(mu_);
  // Condition variable for waking up the job gc thread.
  condition_variable job_gc_thread_cv_;
  std::unique_ptr<Thread> job_gc_thread_;

  TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_