1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/service/task_runner.h"
17
18 #include "tensorflow/core/data/compression_utils.h"
19 #include "tensorflow/core/data/standalone.h"
20 #include "tensorflow/core/framework/dataset.h"
21 #include "tensorflow/core/framework/tensor_util.h"
22 #include "tensorflow/core/lib/gtl/cleanup.h"
23 #include "tensorflow/core/platform/errors.h"
24
25 namespace tensorflow {
26 namespace data {
27 namespace {
28 // How long to wait for other round-robin consumers before returning with an
29 // Unavailable error. This prevents the server from hanging on shutdown when
30 // some round-robin consumers exit earlier than others.
31 const int64 kTimeoutUs = 60 * 1000 * 1000; // 1 minute.
32 // Time to wait before skipping a round if data still isn't available.
33 const int64 kWaitBeforeSkipUs = 100 * 1000; // 100ms.
34
35 // Interprets `element` as a size-1 vector containing a CompressedElement, and
36 // moves the element into `resp`. Returns an error if `element` is of unexpected
37 // size, type, or shape.
MoveCompressedElement(std::vector<Tensor> && element,GetElementResponse & resp)38 Status MoveCompressedElement(std::vector<Tensor>&& element,
39 GetElementResponse& resp) {
40 if (element.size() != 1) {
41 return errors::FailedPrecondition(
42 "Expected dataset to produce a single scalar variant tensor, but the "
43 "dataset produced ",
44 element.size(), " outputs");
45 }
46 if (element[0].dtype() != DT_VARIANT) {
47 return errors::FailedPrecondition(
48 "Expected dataset to produce a single scalar variant tensor, but "
49 "the dataset produced a tensor with type ",
50 DataTypeString(element[0].dtype()));
51 }
52 if (!TensorShapeUtils::IsScalar(element[0].shape())) {
53 return errors::FailedPrecondition(
54 "Expected dataset to produce a single scalar variant tensor, but "
55 "the dataset produced a tensor with shape ",
56 element[0].shape());
57 }
58 Variant& variant = element[0].scalar<Variant>()();
59 CompressedElement* compressed = variant.get<CompressedElement>();
60 if (compressed == nullptr) {
61 return errors::FailedPrecondition(
62 "Expected dataset to produce a CompressedElement variant tensor, but "
63 "it produced ",
64 variant.TypeName());
65 }
66 *resp.mutable_compressed_element() = *compressed;
67 return Status::OK();
68 }
69 } // namespace
70
StandaloneTaskIterator(std::unique_ptr<standalone::Dataset> dataset,std::unique_ptr<standalone::Iterator> iterator)71 StandaloneTaskIterator::StandaloneTaskIterator(
72 std::unique_ptr<standalone::Dataset> dataset,
73 std::unique_ptr<standalone::Iterator> iterator)
74 : dataset_(std::move(dataset)), iterator_(std::move(iterator)) {}
75
GetNext(std::vector<Tensor> & element,bool & end_of_sequence)76 Status StandaloneTaskIterator::GetNext(std::vector<Tensor>& element,
77 bool& end_of_sequence) {
78 return iterator_->GetNext(&element, &end_of_sequence);
79 }
80
Cardinality() const81 int64 StandaloneTaskIterator::Cardinality() const {
82 return dataset_->Get()->Cardinality();
83 }
84
Create(const TaskDef & task_def,std::unique_ptr<TaskIterator> iterator,std::unique_ptr<TaskRunner> & out)85 Status TaskRunner::Create(const TaskDef& task_def,
86 std::unique_ptr<TaskIterator> iterator,
87 std::unique_ptr<TaskRunner>& out) {
88 if (task_def.optional_num_consumers_case() == TaskDef::kNumConsumers) {
89 int64 cardinality = iterator->Cardinality();
90 if (cardinality != kInfiniteCardinality &&
91 cardinality != kUnknownCardinality) {
92 return errors::FailedPrecondition(
93 "Round robin reads require that the input dataset has infinite "
94 "cardinality, but the dataset has cardinality ",
95 cardinality,
96 ". Consider adding a `.repeat()` transformation to the dataset.");
97 }
98 out = absl::make_unique<RoundRobinTaskRunner>(std::move(iterator),
99 task_def.num_consumers());
100 } else {
101 out =
102 absl::make_unique<FirstComeFirstServedTaskRunner>(std::move(iterator));
103 }
104 return Status::OK();
105 }
106
FirstComeFirstServedTaskRunner(std::unique_ptr<TaskIterator> iterator)107 FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner(
108 std::unique_ptr<TaskIterator> iterator)
109 : iterator_(std::move(iterator)) {}
110
GetNext(const GetElementRequest & req,GetElementResponse & resp)111 Status FirstComeFirstServedTaskRunner::GetNext(const GetElementRequest& req,
112 GetElementResponse& resp) {
113 std::vector<Tensor> element;
114 bool end_of_task;
115 resp.set_skip_task(false);
116 TF_RETURN_IF_ERROR(iterator_->GetNext(element, end_of_task));
117 resp.set_end_of_sequence(end_of_task);
118 if (!end_of_task) {
119 return MoveCompressedElement(std::move(element), resp);
120 }
121 return Status::OK();
122 }
123
RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,int64 num_consumers)124 RoundRobinTaskRunner::RoundRobinTaskRunner(
125 std::unique_ptr<TaskIterator> iterator, int64 num_consumers)
126 : num_consumers_(num_consumers),
127 buffer_(num_consumers_),
128 prefetch_thread_(std::move(iterator), num_consumers_) {
129 VLOG(1) << "Creating task runner for distributing data round-robin to "
130 << num_consumers << " consumers";
131 }
132
ValidateRequest(const GetElementRequest & req)133 Status RoundRobinTaskRunner::ValidateRequest(const GetElementRequest& req) {
134 if (req.consumer_index() < 0 || req.round_index() < 0) {
135 return errors::FailedPrecondition(
136 "RoundRobinTaskRunner needs to know the consumer index and element "
137 "index of each request.");
138 }
139 if (req.consumer_index() >= num_consumers_) {
140 return errors::FailedPrecondition(
141 "Requesting data for consumer index ", req.consumer_index(),
142 ", but the task is configured for only ", num_consumers_, " consumers");
143 }
144 return Status::OK();
145 }
146
PrepareFullRound(int64 wait_us)147 Status RoundRobinTaskRunner::PrepareFullRound(int64 wait_us)
148 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
149 VLOG(1) << "Preparing full round for index " << current_round_;
150 // This was the last request to arrive, time to start a new round.
151 TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(wait_us, buffer_));
152 round_skipped_ = buffer_.empty();
153 new_round_cv_.notify_all();
154 return Status::OK();
155 }
156
PreparePartialRound()157 Status RoundRobinTaskRunner::PreparePartialRound()
158 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
159 VLOG(1) << "Starting partial round for " << requests_[first_round_].size()
160 << " consumers";
161 current_round_ = first_round_;
162 new_round_cv_.notify_all();
163 // Indicates that we need a partial round to get consumers back in sync.
164 auto next_round_request = *(requests_[first_round_ + 1].begin()->second);
165 if (next_round_request.skipped_previous_round()) {
166 VLOG(1) << "Skipping partial round";
167 round_skipped_ = true;
168 return Status::OK();
169 }
170 TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(/*wait_us=*/-1, buffer_));
171 round_skipped_ = false;
172 return Status::OK();
173 }
174
PrepareRound(const GetElementRequest & req)175 Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) {
176 mutex_lock l(mu_);
177 first_round_ = std::min(first_round_, req.round_index());
178 absl::flat_hash_map<int64, const GetElementRequest*>& round =
179 requests_[req.round_index()];
180 round[req.consumer_index()] = &req;
181 auto cleanup = gtl::MakeCleanup([&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
182 requests_[req.round_index()].erase(req.consumer_index());
183 });
184 if (current_round_ < req.round_index() && round.size() == num_consumers_) {
185 current_round_ = req.round_index();
186 int64 wait_us = kWaitBeforeSkipUs;
187 if (!req.allow_skip()) {
188 wait_us = -1;
189 }
190 TF_RETURN_IF_ERROR(PrepareFullRound(wait_us));
191 }
192 if (current_round_ < 0 &&
193 requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
194 num_consumers_) {
195 TF_RETURN_IF_ERROR(PreparePartialRound());
196 }
197 while (current_round_ < req.round_index()) {
198 TF_RETURN_IF_ERROR(prefetch_thread_.GetStatus());
199 std::cv_status s =
200 new_round_cv_.wait_for(l, std::chrono::microseconds(kTimeoutUs));
201 if (s == std::cv_status::timeout) {
202 // Clients will retry Unavailable.
203 return errors::Unavailable(
204 "Timeout waiting for other round-robin consumers to be ready.");
205 }
206 }
207 return prefetch_thread_.GetStatus();
208 }
209
GetNext(const GetElementRequest & req,GetElementResponse & resp)210 Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req,
211 GetElementResponse& resp) {
212 TF_RETURN_IF_ERROR(ValidateRequest(req));
213 resp.set_end_of_sequence(false);
214 VLOG(2) << "Received request from consumer index " << req.consumer_index()
215 << " for round " << req.round_index();
216 TF_RETURN_IF_ERROR(PrepareRound(req));
217 tf_shared_lock l(mu_);
218 resp.set_skip_task(round_skipped_);
219 if (round_skipped_) {
220 VLOG(1) << "Buffer not ready, skipping round " << current_round_
221 << " for consumer " << req.consumer_index();
222 return Status::OK();
223 }
224 std::vector<Tensor> element;
225 for (auto& component : buffer_[req.consumer_index()]) {
226 element.push_back(tensor::DeepCopy(component));
227 }
228 if (VLOG_IS_ON(2)) {
229 int64 size = 0;
230 for (auto& component : element) {
231 size += component.TotalBytes();
232 }
233 VLOG(2) << "Returning to consumer " << req.consumer_index() << " for round "
234 << req.round_index() << ". element size " << size;
235 }
236 return MoveCompressedElement(std::move(element), resp);
237 }
238
PrefetchThread(std::unique_ptr<TaskIterator> iterator,int64 round_size)239 PrefetchThread::PrefetchThread(std::unique_ptr<TaskIterator> iterator,
240 int64 round_size)
241 : iterator_(std::move(iterator)), round_size_(round_size) {
242 thread_ = absl::WrapUnique(
243 Env::Default()->StartThread({}, "round-robin-prefetch", [&] { Run(); }));
244 }
245
~PrefetchThread()246 PrefetchThread::~PrefetchThread() {
247 mutex_lock l(mu_);
248 cancelled_ = true;
249 cv_.notify_all();
250 }
251
Run()252 void PrefetchThread::Run() {
253 while (true) {
254 {
255 mutex_lock l(mu_);
256 while (!cancelled_ && buffer_.size() >= round_size_) {
257 cv_.wait(l);
258 }
259 if (cancelled_) {
260 return;
261 }
262 }
263 std::vector<Tensor> element;
264 bool end_of_sequence;
265 Status s = iterator_->GetNext(element, end_of_sequence);
266 if (!s.ok()) {
267 mutex_lock l(mu_);
268 status_ = s;
269 cv_.notify_all();
270 return;
271 }
272 if (end_of_sequence) {
273 mutex_lock l(mu_);
274 status_ = errors::FailedPrecondition(
275 "Encountered end of sequence on a round-robin read iterator. "
276 "Please ensure that the dataset used for round-robin reading has "
277 "infinite cardinality, e.g. by adding a .repeat() transformation "
278 "at the end.");
279 cv_.notify_all();
280 return;
281 }
282 mutex_lock l(mu_);
283 buffer_.push_back(std::move(element));
284 cv_.notify_all();
285 }
286 }
287
FillBuffer(int64 wait_us,std::vector<std::vector<Tensor>> & out)288 Status PrefetchThread::FillBuffer(int64 wait_us,
289 std::vector<std::vector<Tensor>>& out) {
290 int64 start_us = Env::Default()->NowMicros();
291 out.clear();
292 mutex_lock l(mu_);
293 while (buffer_.size() < round_size_ && !cancelled_ && status_.ok()) {
294 int64 remaining_us = start_us + wait_us - Env::Default()->NowMicros();
295 if (wait_us >= 0 && remaining_us <= 0) {
296 break;
297 }
298 cv_.wait_for(l, std::chrono::microseconds(remaining_us));
299 }
300 TF_RETURN_IF_ERROR(status_);
301 if (cancelled_) {
302 return errors::Cancelled("Prefetch thread cancelled");
303 }
304 if (buffer_.size() < round_size_) {
305 DCHECK_GE(wait_us, 0);
306 return Status::OK();
307 }
308 for (auto& elem : buffer_) {
309 out.push_back(std::move(elem));
310 }
311 buffer_.clear();
312 cv_.notify_all();
313 return Status::OK();
314 }
315
GetStatus()316 Status PrefetchThread::GetStatus() {
317 mutex_lock l(mu_);
318 return status_;
319 }
320 } // namespace data
321 } // namespace tensorflow
322