1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
17 
18 #include "tensorflow/core/data/service/common.pb.h"
19 #include "tensorflow/core/data/service/worker.pb.h"
20 #include "tensorflow/core/data/standalone.h"
21 #include "tensorflow/core/platform/status.h"
22 
23 namespace tensorflow {
24 namespace data {
25 
26 // Abstract interface for iterating over a task's elements, one at a time.
27 class TaskIterator {
28  public:
29   virtual ~TaskIterator() = default;  // Virtual: implementations may be destroyed via a base pointer.
30   // Fetches the next element. If the iterator is not yet exhausted, stores the
31   // element in `element` and sets `end_of_sequence` to `false`. Otherwise, sets
32   // `end_of_sequence` to `true`. Failures are reported via the returned `Status`.
33   virtual Status GetNext(std::vector<Tensor>& element,
34                          bool& end_of_sequence) = 0;
35   // Reports the cardinality of the dataset that created this iterator.
36   virtual int64 Cardinality() const = 0;
37 };
38 
39 // `TaskIterator` implementation wrapping a `standalone::Iterator`.
40 class StandaloneTaskIterator : public TaskIterator {
41  public:
42   // `dataset` should be the dataset that created `iterator`.
43   // StandaloneTaskIterator takes ownership of the dataset to ensure it
44   // lives as long as `iterator`. Ownership of `iterator` is transferred too.
45   StandaloneTaskIterator(std::unique_ptr<standalone::Dataset> dataset,
46                          std::unique_ptr<standalone::Iterator> iterator);
47   Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override;  // See `TaskIterator::GetNext`.
48   int64 Cardinality() const override;  // See `TaskIterator::Cardinality`.
49 
50  private:
51   std::unique_ptr<standalone::Dataset> dataset_;  // Declared before `iterator_`, so it is destroyed after it.
52   std::unique_ptr<standalone::Iterator> iterator_;
53 };
54 
55 // Interface for providing elements to task consumers.
56 class TaskRunner {
57  public:
58   // Creates a `TaskRunner` for `task_def`, taking ownership of `iterator`, and stores it in `out`.
59   static Status Create(const TaskDef& task_def,
60                        std::unique_ptr<TaskIterator> iterator,
61                        std::unique_ptr<TaskRunner>& out);  // NOTE(review): which concrete runner is built is decided in the .cc -- not visible here.
62   virtual ~TaskRunner() = default;
63   // Gets the next element for the request `req`; the result is written to `resp`.
64   virtual Status GetNext(const GetElementRequest& req,
65                          GetElementResponse& resp) = 0;
66 };
67 
68 // A task runner which provides elements on a first-come first-served basis.
69 // It does not consider which consumer is making the request.
70 class FirstComeFirstServedTaskRunner : public TaskRunner {
71  public:
72   explicit FirstComeFirstServedTaskRunner(  // Ownership of `iterator` is transferred to the runner.
73       std::unique_ptr<TaskIterator> iterator);
74   Status GetNext(const GetElementRequest& req,
75                  GetElementResponse& resp) override;  // See `TaskRunner::GetNext`.
76 
77  private:
78   std::unique_ptr<TaskIterator> iterator_;  // Owned task iterator; presumably the source of the served elements.
79 };
80 
81 // Thread for prefetching a round worth of elements from a `TaskIterator`.
82 class PrefetchThread {
83  public:
84   explicit PrefetchThread(std::unique_ptr<TaskIterator> iterator,  // Ownership of `iterator` is transferred.
85                           int64 round_size);  // NOTE(review): presumably the number of elements per round -- confirm.
86   ~PrefetchThread();
87   // Runs the prefetch thread. It runs until an error is encountered or the
88   // destructor is called.
89   void Run();
90   // Fills `out` with a round of data. Waits for up to `wait_us` microseconds
91   // before giving up and returning with `out` empty. A negative `wait_us`
92   // signals to wait indefinitely.
93   Status FillBuffer(int64 wait_us, std::vector<std::vector<Tensor>>& out);
94   // Returns the status for any failures encountered by the prefetch thread.
95   Status GetStatus();
96 
97  private:
98   const std::unique_ptr<TaskIterator> iterator_;
99   const int64 round_size_;
100   mutex mu_;  // Guards `buffer_`, `status_` and `cancelled_` (see their TF_GUARDED_BY annotations).
101   // Buffered results for the next round.
102   std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_);
103   // The status if the prefetch thread fails. Initialized to OK.
104   Status status_ TF_GUARDED_BY(mu_) = Status::OK();
105   // Thread which constantly tries to fill `buffer_` up with a round of elements
106   // (presumably `round_size_` of them; this class has no `num_consumers`).
107   std::unique_ptr<Thread> thread_;
108   // Condition variable notified when elements are added to or removed from
109   // `buffer_`, or when `status_` is changed.
110   condition_variable cv_;
111   bool cancelled_ TF_GUARDED_BY(mu_) = false;  // NOTE(review): presumably set by the destructor to stop `Run()` -- confirm in the .cc.
112 };
113 
114 // A task runner which enforces round-robin order for consuming a task's
115 // elements. `RoundRobinTaskRunner` provides elements in a series of "rounds".
116 // In each successive round, the runner waits to receive requests from all
117 // consumers. These requests are blocked until all requests arrive. Once all
118 // requests arrive, the runner hands out elements to consumers in order of their
119 // consumer indices.
120 //
121 // Consumers are expected to successively request consecutive element indices,
122 // starting at 0. The same element can be requested multiple times by the same
123 // consumer, as long as the consumer hasn't yet requested the next element (at
124 // the start of each round we discard elements from the previous round).
125 //
126 // If the worker restarts mid-round, a situation arises where some consumers
127 // are requesting element index `n` while others are requesting element index
128 // `n + 1`. To remedy this, the first round after restart may be a partial
129 // round, where we only serve elements to consumers requesting data for element
130 // index `n`, blocking other consumers until the second round.
131 class RoundRobinTaskRunner : public TaskRunner {
132  public:
133   RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,  // Ownership of `iterator` is transferred.
134                        int64 num_consumers);  // Number of consumers taking part in each round.
135 
136   Status GetNext(const GetElementRequest& req,
137                  GetElementResponse& resp) override;  // See `TaskRunner::GetNext`.
138 
139  private:
140   // Prepares a full round of data. `wait_us` indicates how long to wait before
141   // skipping if a full round of data is not yet ready. Caller must hold `mu_`.
142   Status PrepareFullRound(int64 wait_us) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
143   // Prepares a partial round to get consumers back in sync. Caller must hold `mu_`.
144   Status PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
145   Status ValidateRequest(const GetElementRequest& req);  // NOTE(review): presumably rejects malformed requests; the checks live in the .cc -- confirm.
146   // Prepares data for the next round, blocking until the round is ready to
147   // start.
148   Status PrepareRound(const GetElementRequest& req);
149   const int64 num_consumers_;
150   mutex mu_;  // Guards the members annotated with TF_GUARDED_BY(mu_) below.
151   // Condition variable notified whenever we start a new round of round-robin.
152   condition_variable new_round_cv_;
153   // Outstanding requests, indexed by round number and then consumer index. Stored as raw pointers (presumably non-owning -- confirm).
154   absl::flat_hash_map<int64,
155                       absl::flat_hash_map<int64, const GetElementRequest*>>
156       requests_ TF_GUARDED_BY(mu_);
157   // Index of the first round we plan to serve. At startup, this is the minimum
158   // of all requested element indices. Starts at `kint64max` until a request sets it.
159   int64 first_round_ TF_GUARDED_BY(mu_) = kint64max;
160   int64 current_round_ TF_GUARDED_BY(mu_) = -1;  // NOTE(review): presumably the round being served; starts at -1 -- confirm.
161   bool round_skipped_ TF_GUARDED_BY(mu_) = false;  // NOTE(review): presumably set when a full round was skipped after `wait_us` elapsed -- confirm.
162   // Buffered results for the current round.
163   std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_);
164   // Thread which constantly tries to prepare `num_consumers` elements for the
165   // next round. Declared last, so it is destroyed before the other members.
166   PrefetchThread prefetch_thread_;
167 };
168 
169 }  // namespace data
170 }  // namespace tensorflow
171 
172 #endif  // TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
173