1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_
17 
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/journal.h"
#include "tensorflow/core/data/service/journal.pb.h"
#include "tensorflow/core/lib/core/status.h"
26 
27 namespace tensorflow {
28 namespace data {
29 
30 // A class encapsulating the journaled state of the dispatcher. All state
31 // modifications must be done via `Apply`. This helps to ensure that
32 // replaying the journal will allow us to restore the exact same state.
33 //
34 // The following usage pattern will keep the journal in sync with the state of
35 // the dispatcher:
36 // {
37 //   mutex_lock l(mu_);
//   Update update = ...;  // create an update
//   dispatcher_state.Apply(update);
//   journal_writer.write(update);
41 //   // Unlock mu_
42 // }
43 //
44 // The division of functionality between DispatcherImpl and DispatcherState is
45 // as follows:
46 //   - DispatcherImpl is responsible for handling RPC requests, reading from
47 //     DispatcherState, and deciding what updates to apply to DispatcherState.
48 //     DispatcherImpl handles all synchronization.
49 //   - DispatcherState is responsible for making the state changes requested by
50 //     DispatcherImpl and for providing DispatcherImpl with read-only access to
51 //     the state.
52 //
53 // DispatcherState is thread-compatible but not thread-safe.
54 class DispatcherState {
55  public:
56   DispatcherState();
57   DispatcherState(const DispatcherState&) = delete;
58   DispatcherState& operator=(const DispatcherState&) = delete;
59 
60   // Applies the given update to the dispatcher's state.
61   Status Apply(const Update& update);
62 
63   // A dataset registered with the dispatcher.
64   struct Dataset {
DatasetDataset65     explicit Dataset(int64 dataset_id, int64 fingerprint)
66         : dataset_id(dataset_id), fingerprint(fingerprint) {}
67 
68     const int64 dataset_id;
69     const int64 fingerprint;
70   };
71 
72   // A worker registered with the dispatcher.
73   struct Worker {
WorkerWorker74     explicit Worker(const std::string& address,
75                     const std::string& transfer_address)
76         : address(address), transfer_address(transfer_address) {}
77 
78     const std::string address;
79     const std::string transfer_address;
80   };
81 
82   // A key for identifying a named job. The key contains a user-specified name,
83   // as well as an index describing which iteration of the job we are on.
84   struct NamedJobKey {
NamedJobKeyNamedJobKey85     explicit NamedJobKey(absl::string_view name, int64 index)
86         : name(name), index(index) {}
87 
88     friend bool operator==(const NamedJobKey& lhs, const NamedJobKey& rhs) {
89       return lhs.name == rhs.name && lhs.index == rhs.index;
90     }
91 
92     template <typename H>
AbslHashValueNamedJobKey93     friend H AbslHashValue(H h, const NamedJobKey& k) {
94       return H::combine(std::move(h), k.name, k.index);
95     }
96 
97     const std::string name;
98     const int64 index;
99   };
100 
101   struct DistributedEpochState {
102     // The current repetition.
103     int64 repetition = 0;
104     // Number of splits produced so far by the current split provider.
105     int64 split_provider_index = 0;
106   };
107 
108   struct Task;
109 
110   struct PendingTask {
PendingTaskPendingTask111     explicit PendingTask(std::shared_ptr<Task> task, int64 target_round)
112         : task(std::move(task)), target_round(target_round) {}
113 
114     std::shared_ptr<Task> task;
115     // The target round where we want to insert the task.
116     int64 target_round;
117     // Which consumers have responded that they have successfully blocked
118     // before the target round.
119     absl::flat_hash_set<int64> ready_consumers;
120     // How many times we have failed to add the task.
121     int64 failures = 0;
122   };
123 
124   // A job for processing a dataset.
125   struct Job {
JobJob126     explicit Job(int64 job_id, int64 dataset_id, ProcessingMode processing_mode,
127                  absl::optional<NamedJobKey> named_job_key,
128                  absl::optional<int64> num_consumers)
129         : job_id(job_id),
130           dataset_id(dataset_id),
131           processing_mode(processing_mode),
132           named_job_key(named_job_key),
133           num_consumers(num_consumers) {
134       if (processing_mode == ProcessingMode::DISTRIBUTED_EPOCH) {
135         distributed_epoch_state = DistributedEpochState();
136       }
137     }
138 
139     const int64 job_id;
140     const int64 dataset_id;
141     const ProcessingMode processing_mode;
142     const absl::optional<NamedJobKey> named_job_key;
143     absl::optional<DistributedEpochState> distributed_epoch_state;
144     absl::optional<int64> num_consumers;
145     std::queue<PendingTask> pending_tasks;
146     int64 num_clients = 0;
147     int64 last_client_released_micros = -1;
148     bool finished = false;
149   };
150 
151   struct Task {
TaskTask152     explicit Task(int64 task_id, const std::shared_ptr<Job>& job,
153                   const std::string& worker_address,
154                   const std::string& transfer_address)
155         : task_id(task_id),
156           job(job),
157           worker_address(worker_address),
158           transfer_address(transfer_address) {}
159 
160     const int64 task_id;
161     const std::shared_ptr<Job> job;
162     const std::string worker_address;
163     const std::string transfer_address;
164     int64 starting_round = 0;
165     bool finished = false;
166   };
167 
168   // Returns the next available dataset id.
169   int64 NextAvailableDatasetId() const;
170   // Gets a dataset by id. Returns NOT_FOUND if there is no such dataset.
171   Status DatasetFromId(int64 id, std::shared_ptr<const Dataset>& dataset) const;
172   // Gets a dataset by fingerprint. Returns NOT_FOUND if there is no such
173   // dataset.
174   Status DatasetFromFingerprint(uint64 fingerprint,
175                                 std::shared_ptr<const Dataset>& dataset) const;
176 
177   // Gets a worker by address. Returns NOT_FOUND if there is no such worker.
178   Status WorkerFromAddress(const std::string& address,
179                            std::shared_ptr<const Worker>& worker) const;
180   // Lists all workers registered with the dispatcher.
181   std::vector<std::shared_ptr<const Worker>> ListWorkers() const;
182 
183   // Returns the next available job id.
184   int64 NextAvailableJobId() const;
185   // Returns a list of all jobs.
186   std::vector<std::shared_ptr<const Job>> ListJobs();
187   // Gets a job by id. Returns NOT_FOUND if there is no such job.
188   Status JobFromId(int64 id, std::shared_ptr<const Job>& job) const;
189   // Gets a named job by key. Returns NOT_FOUND if there is no such job.
190   Status NamedJobByKey(NamedJobKey key, std::shared_ptr<const Job>& job) const;
191 
192   // Returns the job associated with the given job client id. Returns NOT_FOUND
193   // if the job_client_id is unknown or has been released.
194   Status JobForJobClientId(int64 job_client_id,
195                            std::shared_ptr<const Job>& job);
196   // Returns the next available job client id.
197   int64 NextAvailableJobClientId() const;
198 
199   // Returns the next available task id.
200   int64 NextAvailableTaskId() const;
201   // Gets a task by id. Returns NOT_FOUND if there is no such task.
202   Status TaskFromId(int64 id, std::shared_ptr<const Task>& task) const;
203   // Stores a list of all tasks for the given job to `tasks`. Returns NOT_FOUND
204   // if there is no such job.
205   Status TasksForJob(int64 job_id,
206                      std::vector<std::shared_ptr<const Task>>& tasks) const;
207   // Stores a list of all tasks for the given worker to `tasks`. Returns
208   // NOT_FOUND if there is no such worker.
209   Status TasksForWorker(const absl::string_view worker_address,
210                         std::vector<std::shared_ptr<const Task>>& tasks) const;
211 
212  private:
213   void RegisterDataset(const RegisterDatasetUpdate& register_dataset);
214   void RegisterWorker(const RegisterWorkerUpdate& register_worker);
215   void CreateJob(const CreateJobUpdate& create_job);
216   void ProduceSplit(const ProduceSplitUpdate& produce_split);
217   void AcquireJobClient(const AcquireJobClientUpdate& acquire_job_client);
218   void ReleaseJobClient(const ReleaseJobClientUpdate& release_job_client);
219   void CreatePendingTask(const CreatePendingTaskUpdate& create_pending_task);
220   void ClientHeartbeat(const ClientHeartbeatUpdate& client_heartbeat);
221   void CreateTask(const CreateTaskUpdate& create_task);
222   void FinishTask(const FinishTaskUpdate& finish_task);
223 
224   int64 next_available_dataset_id_ = 1000;
225   // Registered datasets, keyed by dataset ids.
226   absl::flat_hash_map<int64, std::shared_ptr<Dataset>> datasets_by_id_;
227   // Registered datasets, keyed by dataset fingerprints.
228   absl::flat_hash_map<uint64, std::shared_ptr<Dataset>>
229       datasets_by_fingerprint_;
230 
231   // Registered workers, keyed by address.
232   absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_;
233 
234   int64 next_available_job_id_ = 2000;
235   // Jobs, keyed by job ids.
236   absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_;
237   // Named jobs, keyed by their names and indices. Not all jobs have names, so
238   // this is a subset of the jobs stored in `jobs_`.
239   absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_;
240 
241   int64 next_available_job_client_id_ = 3000;
242   // Mapping from client ids to the jobs they are associated with.
243   absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_for_client_ids_;
244 
245   int64 next_available_task_id_ = 4000;
246   // Tasks, keyed by task ids.
247   absl::flat_hash_map<int64, std::shared_ptr<Task>> tasks_;
248   // Tasks, keyed by job ids.
249   absl::flat_hash_map<int64, std::vector<std::shared_ptr<Task>>> tasks_by_job_;
250   // Tasks, keyed by worker addresses. The values are a map from task id to
251   // task.
252   absl::flat_hash_map<std::string,
253                       absl::flat_hash_map<int64, std::shared_ptr<Task>>>
254       tasks_by_worker_;
255 };
256 
257 }  // namespace data
258 }  // namespace tensorflow
259 
260 #endif  // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_
261