1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_ 16 #define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_ 17 18 #include <queue> 19 20 #include "absl/container/flat_hash_map.h" 21 #include "tensorflow/core/data/service/common.pb.h" 22 #include "tensorflow/core/data/service/data_service.h" 23 #include "tensorflow/core/data/service/journal.h" 24 #include "tensorflow/core/data/service/journal.pb.h" 25 #include "tensorflow/core/lib/core/status.h" 26 27 namespace tensorflow { 28 namespace data { 29 30 // A class encapsulating the journaled state of the dispatcher. All state 31 // modifications must be done via `Apply`. This helps to ensure that 32 // replaying the journal will allow us to restore the exact same state. 33 // 34 // The following usage pattern will keep the journal in sync with the state of 35 // the dispatcher: 36 // { 37 // mutex_lock l(mu_); 38 // Update update = ... // create an update 39 // dispatcher_state.Apply(update); 40 // journal_writer.write(Update); 41 // // Unlock mu_ 42 // } 43 // 44 // The division of functionality between DispatcherImpl and DispatcherState is 45 // as follows: 46 // - DispatcherImpl is responsible for handling RPC requests, reading from 47 // DispatcherState, and deciding what updates to apply to DispatcherState. 48 // DispatcherImpl handles all synchronization. 49 // - DispatcherState is responsible for making the state changes requested by 50 // DispatcherImpl and for providing DispatcherImpl with read-only access to 51 // the state. 52 // 53 // DispatcherState is thread-compatible but not thread-safe. 54 class DispatcherState { 55 public: 56 DispatcherState(); 57 DispatcherState(const DispatcherState&) = delete; 58 DispatcherState& operator=(const DispatcherState&) = delete; 59 60 // Applies the given update to the dispatcher's state. 61 Status Apply(const Update& update); 62 63 // A dataset registered with the dispatcher. 64 struct Dataset { DatasetDataset65 explicit Dataset(int64 dataset_id, int64 fingerprint) 66 : dataset_id(dataset_id), fingerprint(fingerprint) {} 67 68 const int64 dataset_id; 69 const int64 fingerprint; 70 }; 71 72 // A worker registered with the dispatcher. 73 struct Worker { WorkerWorker74 explicit Worker(const std::string& address, 75 const std::string& transfer_address) 76 : address(address), transfer_address(transfer_address) {} 77 78 const std::string address; 79 const std::string transfer_address; 80 }; 81 82 // A key for identifying a named job. The key contains a user-specified name, 83 // as well as an index describing which iteration of the job we are on. 84 struct NamedJobKey { NamedJobKeyNamedJobKey85 explicit NamedJobKey(absl::string_view name, int64 index) 86 : name(name), index(index) {} 87 88 friend bool operator==(const NamedJobKey& lhs, const NamedJobKey& rhs) { 89 return lhs.name == rhs.name && lhs.index == rhs.index; 90 } 91 92 template <typename H> AbslHashValueNamedJobKey93 friend H AbslHashValue(H h, const NamedJobKey& k) { 94 return H::combine(std::move(h), k.name, k.index); 95 } 96 97 const std::string name; 98 const int64 index; 99 }; 100 101 struct DistributedEpochState { 102 // The current repetition. 103 int64 repetition = 0; 104 // Number of splits produced so far by the current split provider. 105 int64 split_provider_index = 0; 106 }; 107 108 struct Task; 109 110 struct PendingTask { PendingTaskPendingTask111 explicit PendingTask(std::shared_ptr<Task> task, int64 target_round) 112 : task(std::move(task)), target_round(target_round) {} 113 114 std::shared_ptr<Task> task; 115 // The target round where we want to insert the task. 116 int64 target_round; 117 // Which consumers have responded that they have successfully blocked 118 // before the target round. 119 absl::flat_hash_set<int64> ready_consumers; 120 // How many times we have failed to add the task. 121 int64 failures = 0; 122 }; 123 124 // A job for processing a dataset. 125 struct Job { JobJob126 explicit Job(int64 job_id, int64 dataset_id, ProcessingMode processing_mode, 127 absl::optional<NamedJobKey> named_job_key, 128 absl::optional<int64> num_consumers) 129 : job_id(job_id), 130 dataset_id(dataset_id), 131 processing_mode(processing_mode), 132 named_job_key(named_job_key), 133 num_consumers(num_consumers) { 134 if (processing_mode == ProcessingMode::DISTRIBUTED_EPOCH) { 135 distributed_epoch_state = DistributedEpochState(); 136 } 137 } 138 139 const int64 job_id; 140 const int64 dataset_id; 141 const ProcessingMode processing_mode; 142 const absl::optional<NamedJobKey> named_job_key; 143 absl::optional<DistributedEpochState> distributed_epoch_state; 144 absl::optional<int64> num_consumers; 145 std::queue<PendingTask> pending_tasks; 146 int64 num_clients = 0; 147 int64 last_client_released_micros = -1; 148 bool finished = false; 149 }; 150 151 struct Task { TaskTask152 explicit Task(int64 task_id, const std::shared_ptr<Job>& job, 153 const std::string& worker_address, 154 const std::string& transfer_address) 155 : task_id(task_id), 156 job(job), 157 worker_address(worker_address), 158 transfer_address(transfer_address) {} 159 160 const int64 task_id; 161 const std::shared_ptr<Job> job; 162 const std::string worker_address; 163 const std::string transfer_address; 164 int64 starting_round = 0; 165 bool finished = false; 166 }; 167 168 // Returns the next available dataset id. 169 int64 NextAvailableDatasetId() const; 170 // Gets a dataset by id. Returns NOT_FOUND if there is no such dataset. 171 Status DatasetFromId(int64 id, std::shared_ptr<const Dataset>& dataset) const; 172 // Gets a dataset by fingerprint. Returns NOT_FOUND if there is no such 173 // dataset. 174 Status DatasetFromFingerprint(uint64 fingerprint, 175 std::shared_ptr<const Dataset>& dataset) const; 176 177 // Gets a worker by address. Returns NOT_FOUND if there is no such worker. 178 Status WorkerFromAddress(const std::string& address, 179 std::shared_ptr<const Worker>& worker) const; 180 // Lists all workers registered with the dispatcher. 181 std::vector<std::shared_ptr<const Worker>> ListWorkers() const; 182 183 // Returns the next available job id. 184 int64 NextAvailableJobId() const; 185 // Returns a list of all jobs. 186 std::vector<std::shared_ptr<const Job>> ListJobs(); 187 // Gets a job by id. Returns NOT_FOUND if there is no such job. 188 Status JobFromId(int64 id, std::shared_ptr<const Job>& job) const; 189 // Gets a named job by key. Returns NOT_FOUND if there is no such job. 190 Status NamedJobByKey(NamedJobKey key, std::shared_ptr<const Job>& job) const; 191 192 // Returns the job associated with the given job client id. Returns NOT_FOUND 193 // if the job_client_id is unknown or has been released. 194 Status JobForJobClientId(int64 job_client_id, 195 std::shared_ptr<const Job>& job); 196 // Returns the next available job client id. 197 int64 NextAvailableJobClientId() const; 198 199 // Returns the next available task id. 200 int64 NextAvailableTaskId() const; 201 // Gets a task by id. Returns NOT_FOUND if there is no such task. 202 Status TaskFromId(int64 id, std::shared_ptr<const Task>& task) const; 203 // Stores a list of all tasks for the given job to `tasks`. Returns NOT_FOUND 204 // if there is no such job. 205 Status TasksForJob(int64 job_id, 206 std::vector<std::shared_ptr<const Task>>& tasks) const; 207 // Stores a list of all tasks for the given worker to `tasks`. Returns 208 // NOT_FOUND if there is no such worker. 209 Status TasksForWorker(const absl::string_view worker_address, 210 std::vector<std::shared_ptr<const Task>>& tasks) const; 211 212 private: 213 void RegisterDataset(const RegisterDatasetUpdate& register_dataset); 214 void RegisterWorker(const RegisterWorkerUpdate& register_worker); 215 void CreateJob(const CreateJobUpdate& create_job); 216 void ProduceSplit(const ProduceSplitUpdate& produce_split); 217 void AcquireJobClient(const AcquireJobClientUpdate& acquire_job_client); 218 void ReleaseJobClient(const ReleaseJobClientUpdate& release_job_client); 219 void CreatePendingTask(const CreatePendingTaskUpdate& create_pending_task); 220 void ClientHeartbeat(const ClientHeartbeatUpdate& client_heartbeat); 221 void CreateTask(const CreateTaskUpdate& create_task); 222 void FinishTask(const FinishTaskUpdate& finish_task); 223 224 int64 next_available_dataset_id_ = 1000; 225 // Registered datasets, keyed by dataset ids. 226 absl::flat_hash_map<int64, std::shared_ptr<Dataset>> datasets_by_id_; 227 // Registered datasets, keyed by dataset fingerprints. 228 absl::flat_hash_map<uint64, std::shared_ptr<Dataset>> 229 datasets_by_fingerprint_; 230 231 // Registered workers, keyed by address. 232 absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_; 233 234 int64 next_available_job_id_ = 2000; 235 // Jobs, keyed by job ids. 236 absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_; 237 // Named jobs, keyed by their names and indices. Not all jobs have names, so 238 // this is a subset of the jobs stored in `jobs_`. 239 absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_; 240 241 int64 next_available_job_client_id_ = 3000; 242 // Mapping from client ids to the jobs they are associated with. 243 absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_for_client_ids_; 244 245 int64 next_available_task_id_ = 4000; 246 // Tasks, keyed by task ids. 247 absl::flat_hash_map<int64, std::shared_ptr<Task>> tasks_; 248 // Tasks, keyed by job ids. 249 absl::flat_hash_map<int64, std::vector<std::shared_ptr<Task>>> tasks_by_job_; 250 // Tasks, keyed by worker addresses. The values are a map from task id to 251 // task. 252 absl::flat_hash_map<std::string, 253 absl::flat_hash_map<int64, std::shared_ptr<Task>>> 254 tasks_by_worker_; 255 }; 256 257 } // namespace data 258 } // namespace tensorflow 259 260 #endif // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_ 261