/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"

#include <atomic>
#include <deque>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h"

namespace tensorflow {
namespace data {
namespace {

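// An iterator that applies a user-provided map functor to the elements of its
// input dataset, running up to `num_parallel_calls` invocations of the
// functor concurrently. A background runner thread schedules the calls;
// completed results are buffered in `invocation_results_` and handed to
// consumers either in input order or, in sloppy mode, as they become ready.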
class ParallelMapIterator : public DatasetBaseIterator {
 public:
  struct Params {
    Params(std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
           int32 num_parallel_calls, bool sloppy, bool preserve_cardinality)
        : parallel_map_functor(std::move(parallel_map_functor)),
          num_parallel_calls(num_parallel_calls),
          sloppy(sloppy),
          preserve_cardinality(preserve_cardinality) {}

    std::unique_ptr<ParallelMapFunctor> parallel_map_functor;
    int32 num_parallel_calls;
    bool sloppy;
    bool preserve_cardinality;
  };

  ParallelMapIterator(
      const typename DatasetBaseIterator::BaseParams& base_params,
      const DatasetBase* input_dataset, Params params)
      : DatasetBaseIterator(base_params),
        input_dataset_(input_dataset),
        parallel_map_functor_(std::move(params.parallel_map_functor)),
        mu_(std::make_shared<mutex>()),
        cond_var_(std::make_shared<condition_variable>()),
        num_parallel_calls_(std::make_shared<model::SharedState>(
            params.num_parallel_calls, mu_, cond_var_)),
        sloppy_(params.sloppy),
        preserve_cardinality_(params.preserve_cardinality) {
    key_prefix_ = base_params.dataset->node_name();
  }

  ~ParallelMapIterator() override {
    mutex_lock l(*mu_);
    // Cancel the runner thread.
    cancelled_ = true;
    cond_var_->notify_all();
    // Wait for all in-flight calls to complete.
    while (num_calls_ > 0) {
      cond_var_->wait(l);
    }
  }

  Status Initialize(IteratorContext* ctx) override {
    mutex_lock l(*mu_);
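    // `kAutoTune` means the user delegated the choice of parallelism to the
    // autotuning framework; seed the value with the runner thread pool size.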
    if (num_parallel_calls_->value == model::kAutoTune) {
      num_parallel_calls_->value = ctx->runner_threadpool_size();
    }
    TF_RETURN_IF_ERROR(
        input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
    return parallel_map_functor_->InitFunc(ctx);
  }

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    std::shared_ptr<InvocationResult> result;
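    // Block until `invocation_results_` yields a result for this caller (the
    // head of the queue in deterministic mode, or the first completed call in
    // sloppy mode), then wait for that call's notification outside the lock.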
    {
      mutex_lock l(*mu_);
      EnsureRunnerThreadStarted(ctx);
      while (ShouldWait(&result)) {
        RecordStop(ctx);
        cond_var_->wait(l);
        RecordStart(ctx);
      }
    }
    RecordStop(ctx);
    result->notification.WaitForNotification();
    RecordStart(ctx);
    return ProcessResult(ctx, result, out_tensors, end_of_sequence);
  }

 protected:
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
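    // Model this iterator as an asynchronous transformation with a 1:1
    // input/output ratio, exposing `num_parallel_calls_` as a tunable
    // "parallelism" parameter.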
    return model::MakeAsyncKnownRatioNode(
        std::move(args),
        /*ratio=*/1,
        {model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1,
                              /*max=*/ctx->runner_threadpool_size())});
  }

  Status SaveInternal(IteratorStateWriter* writer) override {
    mutex_lock l(*mu_);
    // Wait for all in-flight calls to complete.
    while (num_calls_ > 0) {
      cond_var_->wait(l);
    }
    CHECK_EQ(num_calls_, 0);
    TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
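    // Persist every buffered invocation result (status, return values, and
    // end-of-input marker) so that restoring does not re-invoke the function.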
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
                                           invocation_results_.size()));
    for (size_t i = 0; i < invocation_results_.size(); i++) {
      const auto& result = *(invocation_results_[i]);
      TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(strings::StrCat("invocation_results[", i, "].size")),
          result.return_values.size()));
      for (size_t j = 0; j < result.return_values.size(); j++) {
        TF_RETURN_IF_ERROR(writer->WriteTensor(
            full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
            result.return_values[j]));
      }
      if (result.end_of_input) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(strings::StrCat("invocation_results[",
                                                          i, "].end_of_input")),
                                ""));
      }
    }
    return Status::OK();
  }

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    mutex_lock l(*mu_);
    TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
    int64 invocation_results_size;
    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("invocation_results.size"),
                                          &invocation_results_size));
    for (size_t i = 0; i < invocation_results_size; i++) {
      invocation_results_.push_back(std::make_shared<InvocationResult>());
      auto& result = *invocation_results_.back();
      TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
      size_t num_return_values;
      {
        int64 size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            full_name(strings::StrCat("invocation_results[", i, "].size")),
            &size));
        num_return_values = static_cast<size_t>(size);
        if (num_return_values != size) {
          return errors::InvalidArgument(strings::StrCat(
              full_name(strings::StrCat("invocation_results[", i, "].size")),
              ": ", size, " is not a valid value of type size_t."));
        }
      }
      result.return_values.reserve(num_return_values);
      for (size_t j = 0; j < num_return_values; j++) {
        result.return_values.emplace_back();
        TF_RETURN_IF_ERROR(reader->ReadTensor(
            full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
            &result.return_values.back()));
      }
      result.end_of_input = reader->Contains(full_name(
          strings::StrCat("invocation_results[", i, "].end_of_input")));
      result.notification.Notify();
    }
    return Status::OK();
  }

 private:
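  // The state of a single call to the map functor: a notification that fires
  // when the call completes, the call's status, its output tensors, and
  // whether the input iterator was exhausted before the call could be made.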
  struct InvocationResult {
    Notification notification;
    Status status;
    std::vector<Tensor> return_values;
    bool end_of_input;
  };

  void EnsureRunnerThreadStarted(IteratorContext* ctx)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    if (!runner_thread_) {
      auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
      runner_thread_ = ctx->StartThread(
          "tf_data_parallel_map",
          std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy));
    }
  }

  void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
                     const std::shared_ptr<InvocationResult>& result)
      LOCKS_EXCLUDED(*mu_) {
    mutex_lock l(*mu_);
    num_calls_--;
    const auto& stats_aggregator = ctx->stats_aggregator();
    if (stats_aggregator) {
      stats_aggregator->AddScalar(
          stats_utils::ThreadUtilizationScalarName(key_prefix_),
          static_cast<float>(num_calls_) /
              static_cast<float>(num_parallel_calls_->value));
    }
    RecordBufferEnqueue(ctx.get(), result->return_values);
    result->notification.Notify();
    cond_var_->notify_all();
  }

  void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
                    const std::shared_ptr<InvocationResult>& result)
      LOCKS_EXCLUDED(*mu_) {
    // Get the next input element.
    std::vector<Tensor> input_element;
    result->status =
        input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input);
    if (result->end_of_input || !result->status.ok()) {
      CallCompleted(ctx, result);
      return;
    }

    auto done = [this, ctx, result](Status status) {
      result->status.Update(status);
      CallCompleted(ctx, result);
    };

    // Apply the map function on `input_element`, storing the result in
    // `result->return_values`, and invoking `done` when finished.
    parallel_map_functor_->MapFunc(ctx.get(), prefix(),
                                   std::move(input_element),
                                   &result->return_values, std::move(done));
  }

  Status ProcessResult(IteratorContext* ctx,
                       const std::shared_ptr<InvocationResult>& result,
                       std::vector<Tensor>* out_tensors, bool* end_of_sequence)
      LOCKS_EXCLUDED(*mu_) {
    if (!result->end_of_input && result->status.ok()) {
      *out_tensors = std::move(result->return_values);
      RecordBufferDequeue(ctx, *out_tensors);
      *end_of_sequence = false;
      return Status::OK();
    }
    if (errors::IsOutOfRange(result->status)) {
      if (preserve_cardinality_) {
        // To guarantee that the transformation preserves the cardinality of the
        // dataset, we convert `OutOfRange` to `InvalidArgument` as the former
        // may be interpreted by a caller as the end of sequence.
        return errors::InvalidArgument(
            "Function invocation produced OutOfRangeError: ",
            result->status.error_message());
      } else {
        // `f` may deliberately raise `errors::OutOfRange` to indicate
        // that we should terminate the iteration early.
        *end_of_sequence = true;
        return Status::OK();
      }
    }
    *end_of_sequence = result->end_of_input;
    return result->status;
  }

  void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
      LOCKS_EXCLUDED(*mu_) {
    RecordStart(ctx.get());
    auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
    std::vector<std::shared_ptr<InvocationResult>> new_calls;
    {
      tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
      new_calls.reserve(num_parallel_calls_->value);
    }
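    // The runner is "busy" when it has reached the parallelism limit or when
    // the results buffer is full (its size is capped at the parallelism
    // value); in either case it must wait before scheduling more calls.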
    auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
      int64 num_parallel_calls = num_parallel_calls_->value;
      return num_calls_ >= num_parallel_calls ||
             invocation_results_.size() >= num_parallel_calls;
    };
    while (true) {
      {
        mutex_lock l(*mu_);
        while (!cancelled_ && busy()) {
          RecordStop(ctx.get());
          cond_var_->wait(l);
          RecordStart(ctx.get());
        }
        if (cancelled_) {
          return;
        }
        while (!busy()) {
          invocation_results_.push_back(std::make_shared<InvocationResult>());
          new_calls.push_back(invocation_results_.back());
          num_calls_++;
        }
        const auto& stats_aggregator = ctx->stats_aggregator();
        if (stats_aggregator) {
          stats_aggregator->AddScalar(
              stats_utils::ThreadUtilizationScalarName(key_prefix_),
              static_cast<float>(num_calls_) /
                  static_cast<float>(num_parallel_calls_->value));
        }
        cond_var_->notify_all();
      }
      for (const auto& call : new_calls) {
        CallFunction(ctx, call);
      }
      new_calls.clear();
    }
  }

  // Determines whether the caller needs to wait for a result. Upon returning
  // false, `result` will point to the result.
  bool ShouldWait(std::shared_ptr<InvocationResult>* result)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    if (sloppy_) {
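      // In sloppy mode, hand out the first completed result. An
      // `end_of_input` result is skipped unless it sits at the head of the
      // buffer, so that end-of-sequence is not reported while earlier calls
      // may still produce elements.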
      for (auto it = invocation_results_.begin();
           it != invocation_results_.end(); ++it) {
        if ((*it)->notification.HasBeenNotified() &&
            (it == invocation_results_.begin() || !(*it)->end_of_input)) {
          std::swap(*result, *it);
          invocation_results_.erase(it);
          cond_var_->notify_all();
          return false;
        }
      }
    } else if (!invocation_results_.empty()) {
      std::swap(*result, invocation_results_.front());
      invocation_results_.pop_front();
      cond_var_->notify_all();
      return false;
    }
    return true;
  }

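  // Serializes/deserializes a `Status` for checkpointing as an error code
  // plus, for non-OK statuses, an error message, keyed by invocation index.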
  Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
                           const Status& status)
      EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
    if (!status.ok()) {
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(ErrorMessageKey(index), status.error_message()));
    }
    return Status::OK();
  }

  Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
                          Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
    int64 code_int;
    TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
    error::Code code = static_cast<error::Code>(code_int);

    if (code != error::Code::OK) {
      string error_message;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(ErrorMessageKey(index), &error_message));
      *status = Status(code, error_message);
    } else {
      *status = Status::OK();
    }
    return Status::OK();
  }

  string CodeKey(size_t index) {
    return full_name(strings::StrCat("invocation_results[", index, "].code"));
  }

  string ErrorMessageKey(size_t index) {
    return full_name(
        strings::StrCat("invocation_results[", index, "].error_message"));
  }

  const DatasetBase* const input_dataset_;  // Not owned.
  std::unique_ptr<ParallelMapFunctor> parallel_map_functor_;
  // Used for coordination between the main thread and the runner thread.
  const std::shared_ptr<mutex> mu_;
  // Used for coordination between the main thread and the runner thread. In
  // particular, the runner thread should only schedule new calls when the
  // number of in-flight calls is less than the user specified level of
  // parallelism and there are slots available in the `invocation_results_`
  // buffer.
  const std::shared_ptr<condition_variable> cond_var_;
  // Identifies the maximum number of parallel calls.
  const std::shared_ptr<model::SharedState> num_parallel_calls_;
  // Determines whether outputs can be produced in non-deterministic order.
  const bool sloppy_;
  const bool preserve_cardinality_;
  // Counts the number of outstanding calls.
  int64 num_calls_ GUARDED_BY(*mu_) = 0;
  std::unique_ptr<IteratorBase> input_impl_;
  // Buffer for storing the invocation results.
  std::deque<std::shared_ptr<InvocationResult>> invocation_results_
      GUARDED_BY(*mu_);
  std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
  bool cancelled_ GUARDED_BY(*mu_) = false;
  string key_prefix_;
};

}  // namespace

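// Constructs a ParallelMapIterator over `input_dataset`. As an illustrative
// sketch (the `MyMapFunctor` subclass of ParallelMapFunctor is hypothetical),
// a dataset's `MakeIteratorInternal` might call:
//
//   return NewParallelMapIterator(
//       {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
//       absl::make_unique<MyMapFunctor>(), /*num_parallel_calls=*/4,
//       /*sloppy=*/false, /*preserve_cardinality=*/true);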
std::unique_ptr<IteratorBase> NewParallelMapIterator(
    const DatasetBaseIterator::BaseParams& params,
    const DatasetBase* input_dataset,
    std::unique_ptr<ParallelMapFunctor> parallel_map_functor,
    int32 num_parallel_calls, bool sloppy, bool preserve_cardinality) {
  return absl::make_unique<ParallelMapIterator>(
      params, input_dataset,
      ParallelMapIterator::Params{std::move(parallel_map_functor),
                                  num_parallel_calls, sloppy,
                                  preserve_cardinality});
}

}  // namespace data
}  // namespace tensorflow