/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"

#include <atomic>
#include <deque>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h"
namespace tensorflow {
namespace data {
namespace {

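// An iterator that applies a (possibly asynchronous) map function to the
// elements of its input, keeping up to `num_parallel_calls` invocations of
// the function in flight at a time and buffering their results.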
class ParallelMapIterator : public DatasetBaseIterator {
 public:
  struct Params {
    Params(std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
           int32 num_parallel_calls, bool sloppy, bool preserve_cardinality)
        : parallel_map_functor(std::move(parallel_map_functor)),
          num_parallel_calls(num_parallel_calls),
          sloppy(sloppy),
          preserve_cardinality(preserve_cardinality) {}

    std::unique_ptr<ParallelMapFunctor> parallel_map_functor;
    int32 num_parallel_calls;
    bool sloppy;
    bool preserve_cardinality;
  };

  ParallelMapIterator(
      const typename DatasetBaseIterator::BaseParams& base_params,
      const DatasetBase* input_dataset, Params params)
      : DatasetBaseIterator(base_params),
        input_dataset_(input_dataset),
        parallel_map_functor_(std::move(params.parallel_map_functor)),
        mu_(std::make_shared<mutex>()),
        cond_var_(std::make_shared<condition_variable>()),
        num_parallel_calls_(std::make_shared<model::SharedState>(
            params.num_parallel_calls, mu_, cond_var_)),
        sloppy_(params.sloppy),
        preserve_cardinality_(params.preserve_cardinality) {
    key_prefix_ = base_params.dataset->node_name();
  }

  ~ParallelMapIterator() override {
    mutex_lock l(*mu_);
    // Cancel the runner thread.
    cancelled_ = true;
    cond_var_->notify_all();
    // Wait for all in-flight calls to complete.
    while (num_calls_ > 0) {
      cond_var_->wait(l);
    }
  }

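  // If autotuning was requested, seeds the level of parallelism with the
  // runner threadpool size; then creates the input iterator and initializes
  // the map functor.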
  Status Initialize(IteratorContext* ctx) override {
    mutex_lock l(*mu_);
    if (num_parallel_calls_->value == model::kAutoTune) {
      num_parallel_calls_->value = ctx->runner_threadpool_size();
    }
    TF_RETURN_IF_ERROR(
        input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
    return parallel_map_functor_->InitFunc(ctx);
  }

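  // Blocks until an invocation result is available (respecting the ordering
  // semantics implied by `sloppy_`) and forwards it to the caller.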
  Status GetNextInternal(IteratorContext* ctx,
                         std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    std::shared_ptr<InvocationResult> result;
    {
      mutex_lock l(*mu_);
      EnsureRunnerThreadStarted(ctx);
      while (ShouldWait(&result)) {
        RecordStop(ctx);
        cond_var_->wait(l);
        RecordStart(ctx);
      }
    }
    RecordStop(ctx);
    result->notification.WaitForNotification();
    RecordStart(ctx);
    return ProcessResult(ctx, result, out_tensors, end_of_sequence);
  }

 protected:
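  // Models this iterator as an asynchronous, known-ratio node whose
  // "parallelism" parameter can be tuned between 1 and the size of the
  // runner threadpool.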
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    return model::MakeAsyncKnownRatioNode(
        std::move(args),
        /*ratio=*/1,
        {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
                              /*max=*/ctx->runner_threadpool_size())});
  }

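  // Waits for all in-flight calls to complete and then checkpoints the input
  // iterator together with the buffered invocation results.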
  Status SaveInternal(IteratorStateWriter* writer) override {
    mutex_lock l(*mu_);
    // Wait for all in-flight calls to complete.
    while (num_calls_ > 0) {
      cond_var_->wait(l);
    }
    CHECK_EQ(num_calls_, 0);
    TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
                                           invocation_results_.size()));
    for (size_t i = 0; i < invocation_results_.size(); i++) {
      const auto& result = *(invocation_results_[i]);
      TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(strings::StrCat("invocation_results[", i, "].size")),
          result.return_values.size()));
      for (size_t j = 0; j < result.return_values.size(); j++) {
        TF_RETURN_IF_ERROR(writer->WriteTensor(
            full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
            result.return_values[j]));
      }
      if (result.end_of_input) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(strings::StrCat("invocation_results[",
                                                          i, "].end_of_input")),
                                ""));
      }
    }
    return Status::OK();
  }

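  // Restores the input iterator and re-populates `invocation_results_` from
  // the checkpoint, marking each restored result as already notified.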
  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    mutex_lock l(*mu_);
    TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
    int64 invocation_results_size;
    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("invocation_results.size"),
                                          &invocation_results_size));
    for (size_t i = 0; i < invocation_results_size; i++) {
      invocation_results_.push_back(std::make_shared<InvocationResult>());
      auto& result = *invocation_results_.back();
      TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
      size_t num_return_values;
      {
        int64 size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat("invocation_results[", i, "].size")),
            &size));
        num_return_values = static_cast<size_t>(size);
        if (num_return_values != size) {
          return errors::InvalidArgument(strings::StrCat(
              full_name(strings::StrCat("invocation_results[", i, "].size")),
              ": ", size, " is not a valid value of type size_t."));
        }
      }
      result.return_values.reserve(num_return_values);
      for (size_t j = 0; j < num_return_values; j++) {
        result.return_values.emplace_back();
        TF_RETURN_IF_ERROR(reader->ReadTensor(
            full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
            &result.return_values.back()));
      }
      result.end_of_input = reader->Contains(full_name(
          strings::StrCat("invocation_results[", i, "].end_of_input")));
      result.notification.Notify();
    }
    return Status::OK();
  }

 private:
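  // State associated with a single (in-flight or completed) invocation of the
  // map function.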
  struct InvocationResult {
    Notification notification;
    Status status;
    std::vector<Tensor> return_values;
    bool end_of_input;
  };

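  // Starts the background runner thread the first time a result is requested.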
  void EnsureRunnerThreadStarted(IteratorContext* ctx)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    if (!runner_thread_) {
      auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
      runner_thread_ = ctx->StartThread(
          "tf_data_parallel_map",
          std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
    }
  }

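  // Records the completion of a call: decrements the number of in-flight
  // calls, updates the thread utilization statistic, and wakes up any waiting
  // threads.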
  void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
                     const std::shared_ptr<InvocationResult>& result)
      LOCKS_EXCLUDED(*mu_) {
    mutex_lock l(*mu_);
    num_calls_--;
    const auto& stats_aggregator = ctx->stats_aggregator();
    if (stats_aggregator) {
      stats_aggregator->AddScalar(
          stats_utils::ThreadUtilizationScalarName(key_prefix_),
          static_cast<float>(num_calls_) /
              static_cast<float>(num_parallel_calls_->value));
    }
    RecordBufferEnqueue(ctx.get(), result->return_values);
    result->notification.Notify();
    cond_var_->notify_all();
  }

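  // Fetches the next input element and asynchronously applies the map
  // function to it, recording the outcome in `result`.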
  void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
                    const std::shared_ptr<InvocationResult>& result)
      LOCKS_EXCLUDED(*mu_) {
    // Get the next input element.
    std::vector<Tensor> input_element;
    result->status =
        input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input);
    if (result->end_of_input || !result->status.ok()) {
      CallCompleted(ctx, result);
      return;
    }

    auto done = [this, ctx, result](Status status) {
      result->status.Update(status);
      CallCompleted(ctx, result);
    };

    // Apply the map function on `input_element`, storing the result in
    // `result->return_values`, and invoking `done` when finished.
    parallel_map_functor_->MapFunc(ctx.get(), prefix(),
                                   std::move(input_element),
                                   &result->return_values, std::move(done));
  }

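  // Converts an invocation result into the iterator's output, handling
  // `OutOfRange` errors raised by the map function according to
  // `preserve_cardinality_`.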
  Status ProcessResult(IteratorContext* ctx,
                       const std::shared_ptr<InvocationResult>& result,
                       std::vector<Tensor>* out_tensors, bool* end_of_sequence)
      LOCKS_EXCLUDED(*mu_) {
    if (!result->end_of_input && result->status.ok()) {
      *out_tensors = std::move(result->return_values);
      RecordBufferDequeue(ctx, *out_tensors);
      *end_of_sequence = false;
      return Status::OK();
    }
    if (errors::IsOutOfRange(result->status)) {
      if (preserve_cardinality_) {
        // To guarantee that the transformation preserves the cardinality of the
        // dataset, we convert `OutOfRange` to `InvalidArgument` as the former
        // may be interpreted by a caller as the end of sequence.
        return errors::InvalidArgument(
            "Function invocation produced OutOfRangeError: ",
            result->status.error_message());
      } else {
        // `f` may deliberately raise `errors::OutOfRange` to indicate
        // that we should terminate the iteration early.
        *end_of_sequence = true;
        return Status::OK();
      }
    }
    *end_of_sequence = result->end_of_input;
    return result->status;
  }

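  // Background thread that keeps up to `num_parallel_calls_` invocations of
  // the map function in flight, blocking whenever that limit is reached or
  // the `invocation_results_` buffer is full.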
  void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
      LOCKS_EXCLUDED(*mu_) {
    RecordStart(ctx.get());
    auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
    std::vector<std::shared_ptr<InvocationResult>> new_calls;
    {
      tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
      new_calls.reserve(num_parallel_calls_->value);
    }
    auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
      int64 num_parallel_calls = num_parallel_calls_->value;
      return num_calls_ >= num_parallel_calls ||
             invocation_results_.size() >= num_parallel_calls;
    };
    while (true) {
      {
        mutex_lock l(*mu_);
        while (!cancelled_ && busy()) {
          RecordStop(ctx.get());
          cond_var_->wait(l);
          RecordStart(ctx.get());
        }
        if (cancelled_) {
          return;
        }
        while (!busy()) {
          invocation_results_.push_back(std::make_shared<InvocationResult>());
          new_calls.push_back(invocation_results_.back());
          num_calls_++;
        }
        const auto& stats_aggregator = ctx->stats_aggregator();
        if (stats_aggregator) {
          stats_aggregator->AddScalar(
              stats_utils::ThreadUtilizationScalarName(key_prefix_),
              static_cast<float>(num_calls_) /
                  static_cast<float>(num_parallel_calls_->value));
        }
        cond_var_->notify_all();
      }
      for (const auto& call : new_calls) {
        CallFunction(ctx, call);
      }
      new_calls.clear();
    }
  }

  // Determines whether the caller needs to wait for a result. Upon returning
  // false, `result` will point to the result.
  bool ShouldWait(std::shared_ptr<InvocationResult>* result)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    if (sloppy_) {
      for (auto it = invocation_results_.begin();
           it != invocation_results_.end(); ++it) {
        if ((*it)->notification.HasBeenNotified() &&
            (it == invocation_results_.begin() || !(*it)->end_of_input)) {
          std::swap(*result, *it);
          invocation_results_.erase(it);
          cond_var_->notify_all();
          return false;
        }
      }
    } else if (!invocation_results_.empty()) {
      std::swap(*result, invocation_results_.front());
      invocation_results_.pop_front();
      cond_var_->notify_all();
      return false;
    }
    return true;
  }

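  // Serializes `status` into the checkpoint: always the error code and, for
  // non-OK statuses, the error message as well.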
  Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
                           const Status& status)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
    if (!status.ok()) {
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(ErrorMessageKey(index), status.error_message()));
    }
    return Status::OK();
  }

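  // Restores a `Status` previously written by `WriteStatusLocked`.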
  Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
                          Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    int64 code_int;
    TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
    error::Code code = static_cast<error::Code>(code_int);

    if (code != error::Code::OK) {
      string error_message;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(ErrorMessageKey(index), &error_message));
      *status = Status(code, error_message);
    } else {
      *status = Status::OK();
    }
    return Status::OK();
  }

  string CodeKey(size_t index) {
    return full_name(strings::StrCat("invocation_results[", index, "].code"));
  }

  string ErrorMessageKey(size_t index) {
    return full_name(
        strings::StrCat("invocation_results[", index, "].error_message"));
  }

  const DatasetBase* const input_dataset_;  // Not owned.
  std::unique_ptr<ParallelMapFunctor> parallel_map_functor_;
  // Used for coordination between the main thread and the runner thread.
  const std::shared_ptr<mutex> mu_;
  // Used for coordination between the main thread and the runner thread. In
  // particular, the runner thread should only schedule new calls when the
  // number of in-flight calls is less than the user specified level of
  // parallelism and there are slots available in the `invocation_results_`
  // buffer.
  const std::shared_ptr<condition_variable> cond_var_;
  // Identifies the maximum number of parallel calls.
  const std::shared_ptr<model::SharedState> num_parallel_calls_;
  // Determines whether outputs can be produced in non-deterministic order.
  const bool sloppy_;
  const bool preserve_cardinality_;
  // Counts the number of outstanding calls.
  int64 num_calls_ GUARDED_BY(*mu_) = 0;
  std::unique_ptr<IteratorBase> input_impl_;
  // Buffer for storing the invocation results.
  std::deque<std::shared_ptr<InvocationResult>> invocation_results_
      GUARDED_BY(*mu_);
  std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
  bool cancelled_ GUARDED_BY(*mu_) = false;
  string key_prefix_;
};

}  // namespace

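// Constructs a `ParallelMapIterator` over `input_dataset` that uses
// `parallel_map_functor` to map elements with the requested level of
// parallelism.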
std::unique_ptr<IteratorBase> NewParallelMapIterator(
    const DatasetBaseIterator::BaseParams& params,
    const DatasetBase* input_dataset,
    std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
    int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) {
  return absl::make_unique<ParallelMapIterator>(
      params, input_dataset,
      ParallelMapIterator::Params{std::move(parallel_map_functor),
                                  num_parallel_calls, sloppy,
                                  preserve_cardinality});
}

}  // namespace data
}  // namespace tensorflow