/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dataset_store.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/dispatcher_state.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace data {

// A service which coordinates a pool of workers to serve dataset elements over
// RPC.
//
// Glossary:
// * Dataset: A definition of how to generate a potentially large collection of
//   elements.
// * Job: A coordinated phase of reading from the tf.data service. A job
//   produces some amount of data, and (potentially multiple) consumers consume
//   the data from the job until there is no data left. Each job has a
//   ProcessingModeDef which determines what data it produces.
// * Task: A job is broken into multiple tasks, which each represent
//   iterating over all or part of the dataset. Workers process tasks.
// * Consumer: A process reading from the tf.data service.
//
// **Adding workers**
//
// tf.data service supports adding workers mid-job. When a new worker connects
// to the dispatcher, the dispatcher creates new tasks for the worker, one task
// for each outstanding job. Consumers periodically heartbeat to the dispatcher
// to learn about new tasks.
//
// For non-round-robin reads, there is no coordination among consumers. Each
// consumer will start reading from the new task as soon as it learns about the
// task from its heartbeat. Round-robin reads, on the other hand, require
// consumers to read from the same task at each step. This requires
// coordination to ensure that all consumers start reading from the new task in
// the same round.
//
// The protocol for adding round-robin tasks works as follows:
//
// - The dispatcher keeps track of which round each round-robin job is on. This
//   information is reported by consumers in their heartbeats.
// - When a new worker joins and there is an outstanding round-robin job, we
//   create a new task for the job and assign it to the worker.
//   However, we don't yet report the task in consumer heartbeats.
//   We call the task a "pending task" and add it to its job's "pending tasks"
//   queue.
// - When we create a pending task, we choose a "target round" to try adding
//   the task to. The target round is chosen by adding a "target round delta"
//   to the latest reported round for the job.
// - When a consumer heartbeats for a job and there is a pending task for that
//   job, the dispatcher sends a heartbeat response telling the consumer to
//   block before reading from the target round.
// - When a consumer receives a heartbeat response telling it to block
//   (before reading) a round, the consumer tries to block the round. If the
//   consumer has already started the round, it is too late to block it.
// - When consumers heartbeat, they tell the dispatcher their current round and
//   whether they have blocked themselves from reading past a certain round. If
//   a consumer reports a current round exceeding the target round, the target
//   round has failed and needs to be increased. We choose a new target round
//   by doubling the previous target round delta (see the sketch after the
//   example below). If the consumer reports that it has blocked before the
//   target round, we record that the consumer is ready to add the new task.
//   Once all consumers are ready to add the new task, we remove the task from
//   the pending tasks list and begin reporting the task to consumers. We set
//   the "starting_round" field of the task to indicate the target round where
//   all consumers should start reading from the task.
// - If a new worker joins while there are already pending tasks, a pending
//   task for the new worker is created and queued behind the existing tasks.
//   The new task won't be considered until all previous pending tasks have
//   been successfully added.
//
// An example of executing this protocol with two consumers could go as
// follows:
// 1. Consumers read up to round 50 and heartbeat that they are on round 50.
// 2. A new worker joins. Dispatcher chooses round 51 as the target round.
// 3. Consumer 1 heartbeats that its current round is 50. Dispatcher tells it
//    to block round 51.
// 4. Consumer 2 heartbeats that its current round is 51. Dispatcher realizes
//    that it is too late to block round 51 and chooses round 53 as the new
//    target round. Dispatcher tells consumer 2 to block round 53.
// 5. Consumer 1 heartbeats that its current round is 50 and that it has
//    blocked round 51. Dispatcher tells it to block round 53 instead.
//    Dispatcher records that consumer 1 is ready to add a task in round 53.
// 6. Consumer 2 heartbeats that its current round is 52 and it has blocked
//    round 53. Dispatcher realizes that all consumers are blocked on round 53
//    or earlier and promotes the task from pending to regular. Dispatcher
//    sends consumer 2 a task list containing the new task, and tells
//    consumer 2 that it no longer needs to block.
// 7. Consumer 1 heartbeats. Dispatcher sends consumer 1 the task list
//    containing the new task, and tells it that it no longer needs to block.
//
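// A rough sketch of the target-round bookkeeping described above. This is an
// explanatory sketch only: the names `PendingTaskInfo`,
// `ChooseInitialTargetRound`, and `RetryTargetRound` are hypothetical (they
// are not members of this class), and the initial delta value is an
// assumption rather than the actual constant used by the implementation.
//
//   struct PendingTaskInfo {
//     int64 target_round;        // Round in which the task should be added.
//     int64 target_round_delta;  // Doubled whenever a target round fails.
//   };
//
//   // Called when the pending task is created: aim a small delta past the
//   // latest round reported by consumers.
//   void ChooseInitialTargetRound(int64 latest_reported_round,
//                                 PendingTaskInfo& pending) {
//     pending.target_round_delta = 1;  // Illustrative initial value.
//     pending.target_round =
//         latest_reported_round + pending.target_round_delta;
//   }
//
//   // Called when a consumer reports a current round at or past the target
//   // round, so the target can no longer be blocked: double the delta and
//   // choose a new target relative to the latest reported round.
//   void RetryTargetRound(int64 latest_reported_round,
//                         PendingTaskInfo& pending) {
//     pending.target_round_delta *= 2;
//     pending.target_round =
//         latest_reported_round + pending.target_round_delta;
//   }
//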
class DataServiceDispatcherImpl {
 public:
  explicit DataServiceDispatcherImpl(
      const experimental::DispatcherConfig& config);

  ~DataServiceDispatcherImpl();

  // Starts the dispatcher. If there is a journal, this will read from the
  // journal to restore the dispatcher's state.
  Status Start();

  // See dispatcher.proto for API documentation.

  /// Worker-facing API.
  Status WorkerHeartbeat(const WorkerHeartbeatRequest* request,
                         WorkerHeartbeatResponse* response);
  Status WorkerUpdate(const WorkerUpdateRequest* request,
                      WorkerUpdateResponse* response);
  Status GetDatasetDef(const GetDatasetDefRequest* request,
                       GetDatasetDefResponse* response);
  Status GetSplit(const GetSplitRequest* request, GetSplitResponse* response);

  /// Client-facing API.
  Status GetVersion(const GetVersionRequest* request,
                    GetVersionResponse* response);
  Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request,
                              GetOrRegisterDatasetResponse* response);
  Status GetOrCreateJob(const GetOrCreateJobRequest* request,
                        GetOrCreateJobResponse* response);
  Status ReleaseJobClient(const ReleaseJobClientRequest* request,
                          ReleaseJobClientResponse* response);
  Status ClientHeartbeat(const ClientHeartbeatRequest* request,
                         ClientHeartbeatResponse* response);
  Status GetWorkers(const GetWorkersRequest* request,
                    GetWorkersResponse* response);

 private:
  // Restores a `SplitProvider` from the state in `job` and stores it in
  // `restored`.
  Status RestoreSplitProvider(const DispatcherState::Job& job,
                              std::unique_ptr<SplitProvider>& restored)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Makes a split provider for the specified `dataset_id`, and stores it in
  // `split_provider`.
  Status MakeSplitProvider(int64 dataset_id,
                           std::unique_ptr<SplitProvider>& split_provider)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Registers a dataset with the given fingerprint, storing the new dataset's
  // id in `dataset_id`.
  Status RegisterDataset(uint64 fingerprint, const DatasetDef& dataset,
                         int64& dataset_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Gets a worker's stub from `worker_stubs_`, or if none exists, creates a
  // stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is
  // stored in `out_stub`.
  Status GetOrCreateWorkerStub(const std::string& worker_address,
                               WorkerService::Stub*& out_stub)
      TF_LOCKS_EXCLUDED(mu_);
  // Creates a job and stores it in `job`. This method updates the
  // dispatcher state with the new job, but does not assign tasks to workers.
  Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
                   absl::optional<DispatcherState::NamedJobKey> named_job_key,
                   absl::optional<int64> num_consumers,
                   std::shared_ptr<const DispatcherState::Job>& job)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates tasks for the specified worker, one task for every unfinished job.
  Status CreateTasksForWorker(const std::string& worker_address);
  // Acquires a job client id to read from the given job and sets
  // `job_client_id`.
  Status AcquireJobClientId(
      const std::shared_ptr<const DispatcherState::Job>& job,
      int64& job_client_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates one task for each worker, for the given job. The created tasks are
  // stored in `tasks`. This method only updates dispatcher metadata with the
  // new tasks, but doesn't assign the tasks to the workers.
  Status CreateTasksForJob(
      std::shared_ptr<const DispatcherState::Job> job,
      std::vector<std::shared_ptr<const DispatcherState::Task>>& tasks)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Creates a pending task for a round-robin job. All consumers need to agree
  // on which round to add the task in before the pending task can be promoted
  // to a regular task.
  Status CreatePendingTask(std::shared_ptr<const DispatcherState::Job> job,
                           const std::string& worker_address)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates a new task for a job, storing the created task in `task`.
  Status CreateTask(std::shared_ptr<const DispatcherState::Job> job,
                    const std::string& worker_address,
                    std::shared_ptr<const DispatcherState::Task>& task);
  // Assigns the list of tasks to the workers indicated by their
  // `worker_address` fields.
  Status AssignTasks(
      std::vector<std::shared_ptr<const DispatcherState::Task>> tasks)
      TF_LOCKS_EXCLUDED(mu_);
  // Assigns a task to the worker indicated by its `worker_address` field.
  Status AssignTask(std::shared_ptr<const DispatcherState::Task> task)
      TF_LOCKS_EXCLUDED(mu_);
  // Validates that an existing job matches the given processing_mode and
  // dataset_id, returning an error status describing any difference.
  Status ValidateMatchingJob(std::shared_ptr<const DispatcherState::Job> job,
                             ProcessingMode processing_mode, int64 dataset_id)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Checks that the dispatcher has started, returning UNAVAILABLE if it
  // hasn't.
  Status CheckStarted() TF_LOCKS_EXCLUDED(mu_);
  // Records that a split was produced by a call to `GetSplit`.
  Status RecordSplitProduced(int64 job_id, int64 repetition, bool finished)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Applies a state update, updating both the journal and the in-memory state.
  Status Apply(const Update& update) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Applies a state update, but doesn't update the journal. Only meant to be
  // used when recovering state when the dispatcher starts.
  Status ApplyWithoutJournaling(const Update& update)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // A thread which periodically checks for jobs to clean up.
  void JobGcThread();
  // Scans for old jobs and marks them as finished.
  Status GcOldJobs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Gets a `DatasetDef` from `dataset_store_` for the given dataset id, and
  // stores it in `dataset_def`.
  Status GetDatasetDef(int64 dataset_id,
                       std::shared_ptr<const DatasetDef>& dataset_def)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Gets a `DatasetDef` from `dataset_store_` for the given dataset, and
  // stores it in `dataset_def`.
  Status GetDatasetDef(const DispatcherState::Dataset& dataset,
                       std::shared_ptr<const DatasetDef>& dataset_def)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  const experimental::DispatcherConfig& config_;
  Env* env_;

  mutex mu_;
  bool started_ TF_GUARDED_BY(mu_) = false;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;

  // Cached worker stubs for communicating with workers.
  absl::flat_hash_map<std::string, std::unique_ptr<WorkerService::Stub>>
      worker_stubs_ TF_GUARDED_BY(mu_);
  // Store of dataset definitions.
  std::unique_ptr<DatasetStore> dataset_store_ TF_GUARDED_BY(mu_);
  // Mapping from job id to `SplitProvider`s for jobs with processing mode
  // DISTRIBUTED_EPOCH.
  absl::flat_hash_map<int64, std::unique_ptr<SplitProvider>> split_providers_
      TF_GUARDED_BY(mu_);
  // Mapping from round-robin job id to the round the job is currently on. This
  // is based on the data provided by client heartbeats, and may be stale.
  absl::flat_hash_map<int64, int64> round_robin_rounds_ TF_GUARDED_BY(mu_);

  absl::optional<std::unique_ptr<JournalWriter>> journal_writer_
      TF_GUARDED_BY(mu_);
  DispatcherState state_ TF_GUARDED_BY(mu_);
  // Condition variable for waking up the job gc thread.
  condition_variable job_gc_thread_cv_;
  std::unique_ptr<Thread> job_gc_thread_;

  TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_