1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/service/task_runner.h"
17 
18 #include "tensorflow/core/data/compression_utils.h"
19 #include "tensorflow/core/data/standalone.h"
20 #include "tensorflow/core/framework/dataset.h"
21 #include "tensorflow/core/framework/tensor_util.h"
22 #include "tensorflow/core/lib/gtl/cleanup.h"
23 #include "tensorflow/core/platform/errors.h"
24 
25 namespace tensorflow {
26 namespace data {
27 namespace {
28 // How long to wait for other round-robin consumers before returning with an
29 // Unavailable error. This prevents the server from hanging on shutdown when
30 // some round-robin consumers exit earlier than others.
31 const int64 kTimeoutUs = 60 * 1000 * 1000;  // 1 minute.
32 // Time to wait before skipping a round if data still isn't available.
33 const int64 kWaitBeforeSkipUs = 100 * 1000;  // 100ms.
34 
35 // Interprets `element` as a size-1 vector containing a CompressedElement, and
36 // moves the element into `resp`. Returns an error if `element` is of unexpected
37 // size, type, or shape.
MoveCompressedElement(std::vector<Tensor> && element,GetElementResponse & resp)38 Status MoveCompressedElement(std::vector<Tensor>&& element,
39                              GetElementResponse& resp) {
40   if (element.size() != 1) {
41     return errors::FailedPrecondition(
42         "Expected dataset to produce a single scalar variant tensor, but the "
43         "dataset produced ",
44         element.size(), " outputs");
45   }
46   if (element[0].dtype() != DT_VARIANT) {
47     return errors::FailedPrecondition(
48         "Expected dataset to produce a single scalar variant tensor, but "
49         "the dataset produced a tensor with type ",
50         DataTypeString(element[0].dtype()));
51   }
52   if (!TensorShapeUtils::IsScalar(element[0].shape())) {
53     return errors::FailedPrecondition(
54         "Expected dataset to produce a single scalar variant tensor, but "
55         "the dataset produced a tensor with shape ",
56         element[0].shape());
57   }
58   Variant& variant = element[0].scalar<Variant>()();
59   CompressedElement* compressed = variant.get<CompressedElement>();
60   if (compressed == nullptr) {
61     return errors::FailedPrecondition(
62         "Expected dataset to produce a CompressedElement variant tensor, but "
63         "it produced ",
64         variant.TypeName());
65   }
66   *resp.mutable_compressed_element() = *compressed;
67   return Status::OK();
68 }
69 }  // namespace
70 
StandaloneTaskIterator(std::unique_ptr<standalone::Dataset> dataset,std::unique_ptr<standalone::Iterator> iterator)71 StandaloneTaskIterator::StandaloneTaskIterator(
72     std::unique_ptr<standalone::Dataset> dataset,
73     std::unique_ptr<standalone::Iterator> iterator)
74     : dataset_(std::move(dataset)), iterator_(std::move(iterator)) {}
75 
GetNext(std::vector<Tensor> & element,bool & end_of_sequence)76 Status StandaloneTaskIterator::GetNext(std::vector<Tensor>& element,
77                                        bool& end_of_sequence) {
78   return iterator_->GetNext(&element, &end_of_sequence);
79 }
80 
Cardinality() const81 int64 StandaloneTaskIterator::Cardinality() const {
82   return dataset_->Get()->Cardinality();
83 }
84 
Create(const TaskDef & task_def,std::unique_ptr<TaskIterator> iterator,std::unique_ptr<TaskRunner> & out)85 Status TaskRunner::Create(const TaskDef& task_def,
86                           std::unique_ptr<TaskIterator> iterator,
87                           std::unique_ptr<TaskRunner>& out) {
88   if (task_def.optional_num_consumers_case() == TaskDef::kNumConsumers) {
89     int64 cardinality = iterator->Cardinality();
90     if (cardinality != kInfiniteCardinality &&
91         cardinality != kUnknownCardinality) {
92       return errors::FailedPrecondition(
93           "Round robin reads require that the input dataset has infinite "
94           "cardinality, but the dataset has cardinality ",
95           cardinality,
96           ". Consider adding a `.repeat()` transformation to the dataset.");
97     }
98     out = absl::make_unique<RoundRobinTaskRunner>(std::move(iterator),
99                                                   task_def.num_consumers());
100   } else {
101     out =
102         absl::make_unique<FirstComeFirstServedTaskRunner>(std::move(iterator));
103   }
104   return Status::OK();
105 }
106 
FirstComeFirstServedTaskRunner(std::unique_ptr<TaskIterator> iterator)107 FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner(
108     std::unique_ptr<TaskIterator> iterator)
109     : iterator_(std::move(iterator)) {}
110 
GetNext(const GetElementRequest & req,GetElementResponse & resp)111 Status FirstComeFirstServedTaskRunner::GetNext(const GetElementRequest& req,
112                                                GetElementResponse& resp) {
113   std::vector<Tensor> element;
114   bool end_of_task;
115   resp.set_skip_task(false);
116   TF_RETURN_IF_ERROR(iterator_->GetNext(element, end_of_task));
117   resp.set_end_of_sequence(end_of_task);
118   if (!end_of_task) {
119     return MoveCompressedElement(std::move(element), resp);
120   }
121   return Status::OK();
122 }
123 
RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,int64 num_consumers)124 RoundRobinTaskRunner::RoundRobinTaskRunner(
125     std::unique_ptr<TaskIterator> iterator, int64 num_consumers)
126     : num_consumers_(num_consumers),
127       buffer_(num_consumers_),
128       prefetch_thread_(std::move(iterator), num_consumers_) {
129   VLOG(1) << "Creating task runner for distributing data round-robin to "
130           << num_consumers << " consumers";
131 }
132 
ValidateRequest(const GetElementRequest & req)133 Status RoundRobinTaskRunner::ValidateRequest(const GetElementRequest& req) {
134   if (req.consumer_index() < 0 || req.round_index() < 0) {
135     return errors::FailedPrecondition(
136         "RoundRobinTaskRunner needs to know the consumer index and element "
137         "index of each request.");
138   }
139   if (req.consumer_index() >= num_consumers_) {
140     return errors::FailedPrecondition(
141         "Requesting data for consumer index ", req.consumer_index(),
142         ", but the task is configured for only ", num_consumers_, " consumers");
143   }
144   return Status::OK();
145 }
146 
PrepareFullRound(int64 wait_us)147 Status RoundRobinTaskRunner::PrepareFullRound(int64 wait_us)
148     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
149   VLOG(1) << "Preparing full round for index " << current_round_;
150   // This was the last request to arrive, time to start a new round.
151   TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(wait_us, buffer_));
152   round_skipped_ = buffer_.empty();
153   new_round_cv_.notify_all();
154   return Status::OK();
155 }
156 
PreparePartialRound()157 Status RoundRobinTaskRunner::PreparePartialRound()
158     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
159   VLOG(1) << "Starting partial round for " << requests_[first_round_].size()
160           << " consumers";
161   current_round_ = first_round_;
162   new_round_cv_.notify_all();
163   // Indicates that we need a partial round to get consumers back in sync.
164   auto next_round_request = *(requests_[first_round_ + 1].begin()->second);
165   if (next_round_request.skipped_previous_round()) {
166     VLOG(1) << "Skipping partial round";
167     round_skipped_ = true;
168     return Status::OK();
169   }
170   TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(/*wait_us=*/-1, buffer_));
171   round_skipped_ = false;
172   return Status::OK();
173 }
174 
PrepareRound(const GetElementRequest & req)175 Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) {
176   mutex_lock l(mu_);
177   first_round_ = std::min(first_round_, req.round_index());
178   absl::flat_hash_map<int64, const GetElementRequest*>& round =
179       requests_[req.round_index()];
180   round[req.consumer_index()] = &req;
181   auto cleanup = gtl::MakeCleanup([&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
182     requests_[req.round_index()].erase(req.consumer_index());
183   });
184   if (current_round_ < req.round_index() && round.size() == num_consumers_) {
185     current_round_ = req.round_index();
186     int64 wait_us = kWaitBeforeSkipUs;
187     if (!req.allow_skip()) {
188       wait_us = -1;
189     }
190     TF_RETURN_IF_ERROR(PrepareFullRound(wait_us));
191   }
192   if (current_round_ < 0 &&
193       requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
194           num_consumers_) {
195     TF_RETURN_IF_ERROR(PreparePartialRound());
196   }
197   while (current_round_ < req.round_index()) {
198     TF_RETURN_IF_ERROR(prefetch_thread_.GetStatus());
199     std::cv_status s =
200         new_round_cv_.wait_for(l, std::chrono::microseconds(kTimeoutUs));
201     if (s == std::cv_status::timeout) {
202       // Clients will retry Unavailable.
203       return errors::Unavailable(
204           "Timeout waiting for other round-robin consumers to be ready.");
205     }
206   }
207   return prefetch_thread_.GetStatus();
208 }
209 
GetNext(const GetElementRequest & req,GetElementResponse & resp)210 Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req,
211                                      GetElementResponse& resp) {
212   TF_RETURN_IF_ERROR(ValidateRequest(req));
213   resp.set_end_of_sequence(false);
214   VLOG(2) << "Received request from consumer index " << req.consumer_index()
215           << " for round " << req.round_index();
216   TF_RETURN_IF_ERROR(PrepareRound(req));
217   tf_shared_lock l(mu_);
218   resp.set_skip_task(round_skipped_);
219   if (round_skipped_) {
220     VLOG(1) << "Buffer not ready, skipping round " << current_round_
221             << " for consumer " << req.consumer_index();
222     return Status::OK();
223   }
224   std::vector<Tensor> element;
225   for (auto& component : buffer_[req.consumer_index()]) {
226     element.push_back(tensor::DeepCopy(component));
227   }
228   if (VLOG_IS_ON(2)) {
229     int64 size = 0;
230     for (auto& component : element) {
231       size += component.TotalBytes();
232     }
233     VLOG(2) << "Returning to consumer " << req.consumer_index() << " for round "
234             << req.round_index() << ". element size " << size;
235   }
236   return MoveCompressedElement(std::move(element), resp);
237 }
238 
PrefetchThread(std::unique_ptr<TaskIterator> iterator,int64 round_size)239 PrefetchThread::PrefetchThread(std::unique_ptr<TaskIterator> iterator,
240                                int64 round_size)
241     : iterator_(std::move(iterator)), round_size_(round_size) {
242   thread_ = absl::WrapUnique(
243       Env::Default()->StartThread({}, "round-robin-prefetch", [&] { Run(); }));
244 }
245 
~PrefetchThread()246 PrefetchThread::~PrefetchThread() {
247   mutex_lock l(mu_);
248   cancelled_ = true;
249   cv_.notify_all();
250 }
251 
Run()252 void PrefetchThread::Run() {
253   while (true) {
254     {
255       mutex_lock l(mu_);
256       while (!cancelled_ && buffer_.size() >= round_size_) {
257         cv_.wait(l);
258       }
259       if (cancelled_) {
260         return;
261       }
262     }
263     std::vector<Tensor> element;
264     bool end_of_sequence;
265     Status s = iterator_->GetNext(element, end_of_sequence);
266     if (!s.ok()) {
267       mutex_lock l(mu_);
268       status_ = s;
269       cv_.notify_all();
270       return;
271     }
272     if (end_of_sequence) {
273       mutex_lock l(mu_);
274       status_ = errors::FailedPrecondition(
275           "Encountered end of sequence on a round-robin read iterator. "
276           "Please ensure that the dataset used for round-robin reading has "
277           "infinite cardinality, e.g. by adding a .repeat() transformation "
278           "at the end.");
279       cv_.notify_all();
280       return;
281     }
282     mutex_lock l(mu_);
283     buffer_.push_back(std::move(element));
284     cv_.notify_all();
285   }
286 }
287 
FillBuffer(int64 wait_us,std::vector<std::vector<Tensor>> & out)288 Status PrefetchThread::FillBuffer(int64 wait_us,
289                                   std::vector<std::vector<Tensor>>& out) {
290   int64 start_us = Env::Default()->NowMicros();
291   out.clear();
292   mutex_lock l(mu_);
293   while (buffer_.size() < round_size_ && !cancelled_ && status_.ok()) {
294     int64 remaining_us = start_us + wait_us - Env::Default()->NowMicros();
295     if (wait_us >= 0 && remaining_us <= 0) {
296       break;
297     }
298     cv_.wait_for(l, std::chrono::microseconds(remaining_us));
299   }
300   TF_RETURN_IF_ERROR(status_);
301   if (cancelled_) {
302     return errors::Cancelled("Prefetch thread cancelled");
303   }
304   if (buffer_.size() < round_size_) {
305     DCHECK_GE(wait_us, 0);
306     return Status::OK();
307   }
308   for (auto& elem : buffer_) {
309     out.push_back(std::move(elem));
310   }
311   buffer_.clear();
312   cv_.notify_all();
313   return Status::OK();
314 }
315 
GetStatus()316 Status PrefetchThread::GetStatus() {
317   mutex_lock l(mu_);
318   return status_;
319 }
320 }  // namespace data
321 }  // namespace tensorflow
322