1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_ 16 #define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_ 17 18 #include "tensorflow/core/data/service/common.pb.h" 19 #include "tensorflow/core/data/service/worker.pb.h" 20 #include "tensorflow/core/data/standalone.h" 21 #include "tensorflow/core/platform/status.h" 22 23 namespace tensorflow { 24 namespace data { 25 26 // Iterator over a task's elements. 27 class TaskIterator { 28 public: 29 virtual ~TaskIterator() = default; 30 // If the iterator is not yet exhausted, `GetNext` stores the next element in 31 // `element` and sets `end_of_sequence` to `false`. Otherwise, sets 32 // `end_of_sequence to `true`. 33 virtual Status GetNext(std::vector<Tensor>& element, 34 bool& end_of_sequence) = 0; 35 // Reports the cardinality of the dataset that created this iterator. 36 virtual int64 Cardinality() const = 0; 37 }; 38 39 // Implementation of TaskIterator wrapping a standalone iterator. 40 class StandaloneTaskIterator : public TaskIterator { 41 public: 42 // `dataset` should be the dataset that created `iterator`. 43 // StandaloneTaskIterator takes ownership of the dataset to ensures it 44 // lives as long as `iterator`. 
45 StandaloneTaskIterator(std::unique_ptr<standalone::Dataset> dataset, 46 std::unique_ptr<standalone::Iterator> iterator); 47 Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override; 48 int64 Cardinality() const override; 49 50 private: 51 std::unique_ptr<standalone::Dataset> dataset_; 52 std::unique_ptr<standalone::Iterator> iterator_; 53 }; 54 55 // Interface for providing elements to task consumers. 56 class TaskRunner { 57 public: 58 // Creates a `TaskRunner` and stores it in `out`. 59 static Status Create(const TaskDef& task_def, 60 std::unique_ptr<TaskIterator> iterator, 61 std::unique_ptr<TaskRunner>& out); 62 virtual ~TaskRunner() = default; 63 // Gets the next element for the given request. 64 virtual Status GetNext(const GetElementRequest& req, 65 GetElementResponse& resp) = 0; 66 }; 67 68 // A task runner which provides elements on a first-come first-served basis. 69 // It does not consider which consumer is making the request. 70 class FirstComeFirstServedTaskRunner : public TaskRunner { 71 public: 72 explicit FirstComeFirstServedTaskRunner( 73 std::unique_ptr<TaskIterator> iterator); 74 Status GetNext(const GetElementRequest& req, 75 GetElementResponse& resp) override; 76 77 private: 78 std::unique_ptr<TaskIterator> iterator_; 79 }; 80 81 // Thread for prefetching a round worth of elements. 82 class PrefetchThread { 83 public: 84 explicit PrefetchThread(std::unique_ptr<TaskIterator> iterator, 85 int64 round_size); 86 ~PrefetchThread(); 87 // Runs the prefetch thread. It runs until an error is encountered or the 88 // destructor is called. 89 void Run(); 90 // Fills `out` with a round of data. Waits for up to `wait_us` micoseconds 91 // before giving up and returning with `out` empty. A negative `wait_us` 92 // signals to wait indefinitely. 93 Status FillBuffer(int64 wait_us, std::vector<std::vector<Tensor>>& out); 94 // Returns the status for any failures encountered by the prefetch thread. 
95 Status GetStatus(); 96 97 private: 98 const std::unique_ptr<TaskIterator> iterator_; 99 const int64 round_size_; 100 mutex mu_; 101 // Buffered results for the next round. 102 std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_); 103 // The status if the prefetch thread fails. 104 Status status_ TF_GUARDED_BY(mu_) = Status::OK(); 105 // Thread which constantly tries to fill `buffer_` up with 106 // `num_consumers` elements. 107 std::unique_ptr<Thread> thread_; 108 // Condition variable notified when elements are added to or removed from 109 // `buffer_`, or when `status_` is changed. 110 condition_variable cv_; 111 bool cancelled_ TF_GUARDED_BY(mu_) = false; 112 }; 113 114 // A task runner which enforces round-robin order for consuming a task's 115 // elements. `RoundRobinTaskRunner` provides elements in a series of "rounds". 116 // In each successive round, the runner waits to receive requests from all 117 // consumers. These requests are blocked until all requests arrive. Once all 118 // requests arrive, the runner hands out elements to consumers in order of their 119 // consumer indices. 120 // 121 // Consumers are expected to successively request consecutive element indices, 122 // starting at 0. The same element can be requested multiple times by the same 123 // consumer, as long as the consumer hasn't yet requested the next element (at 124 // the start of each round we discard elements from the previous round). 125 // 126 // If the worker restarts mid-round, a situation arises where some consumers 127 // are requesting element index `n` while others are requesting element index 128 // `n + 1`. To remedy this, the first round after restart may be a partial 129 // round, where we only serve elements to consumers requesting data for element 130 // index `n`, blocking other consumers until the second round. 
class RoundRobinTaskRunner : public TaskRunner {
 public:
  // Takes ownership of `iterator`. `num_consumers` is stored in
  // `num_consumers_`.
  // NOTE(review): presumably `iterator` and `num_consumers` are forwarded to
  // `prefetch_thread_` (its constructor takes an iterator and a round size) --
  // confirm in the .cc.
  RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
                       int64 num_consumers);

  // Gets the next element for the given request (see `TaskRunner::GetNext`).
  Status GetNext(const GetElementRequest& req,
                 GetElementResponse& resp) override;

 private:
  // Prepares a full round of data. `wait_us` indicates how long to wait before
  // skipping if a full round of data is not yet ready.
  // Caller must hold `mu_` exclusively.
  Status PrepareFullRound(int64 wait_us) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Prepares a partial round to get consumers back in sync.
  // Caller must hold `mu_` exclusively.
  Status PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Validates `req`; the specific checks are defined in the .cc. Carries no
  // lock annotation, unlike the two `Prepare*Round` helpers above.
  Status ValidateRequest(const GetElementRequest& req);
  // Prepares data for the next round, blocking until the round is ready to
  // start.
  Status PrepareRound(const GetElementRequest& req);
  const int64 num_consumers_;
  mutex mu_;
  // Condition variable notified whenever we start a new round of round-robin.
  condition_variable new_round_cv_;
  // Outstanding requests, indexed by round number and then consumer index.
  // The stored request pointers are non-owning.
  absl::flat_hash_map<int64,
                      absl::flat_hash_map<int64, const GetElementRequest*>>
      requests_ TF_GUARDED_BY(mu_);
  // Index of the first round we plan to serve. At startup, this is the minimum
  // of all requested element indices.
  int64 first_round_ TF_GUARDED_BY(mu_) = kint64max;
  // NOTE(review): presumably the index of the round currently being served,
  // with the initial -1 meaning no round has started yet -- confirm in the .cc.
  int64 current_round_ TF_GUARDED_BY(mu_) = -1;
  // NOTE(review): presumably records that a round was skipped because data was
  // not ready within `wait_us` (see `PrepareFullRound`) -- confirm in the .cc.
  bool round_skipped_ TF_GUARDED_BY(mu_) = false;
  // Buffered results for the current round.
  std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_);
  // Thread which constantly tries to prepare `num_consumers` elements for the
  // next round.
  // Declared last, so it is constructed after and destroyed before every
  // member above (C++ destroys members in reverse declaration order); keep it
  // last unless that ordering is known not to matter.
  PrefetchThread prefetch_thread_;
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_