1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
17 #define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
18 
#include <algorithm>
#include <atomic>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <unordered_map>
#include <utility>
#include <vector>
26 
27 #include "absl/types/optional.h"
28 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
29 #include "tensorflow/core/kernels/batching_util/periodic_function.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/core/threadpool.h"
33 #include "tensorflow/core/platform/byte_order.h"
34 #include "tensorflow/core/platform/cpu_info.h"
35 #include "tensorflow/core/platform/env.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/thread_annotations.h"
38 #include "tensorflow/core/platform/threadpool_interface.h"
39 #include "tensorflow/core/platform/types.h"
40 #include "tensorflow/core/profiler/lib/connected_traceme.h"
41 
42 namespace tensorflow {
43 namespace serving {
namespace internal {
// Forward declarations. Both templates are defined later in this header (in
// this same namespace); they are declared here because
// AdaptiveSharedBatchScheduler names them (as a friend and in member
// signatures) before their definitions appear.
template <typename TaskType>
class ASBSBatch;

template <typename TaskType>
class ASBSQueue;
}  // namespace internal
51 
52 // Shared batch scheduler designed to minimize latency. The scheduler keeps
53 // track of a number of queues (one per model or model version) which are
54 // continuously enqueuing requests. The scheduler groups the requests into
55 // batches which it periodically sends off for processing (see
56 // shared_batch_scheduler.h for more details). AdaptiveSharedBatchScheduler
57 // (ASBS) prioritizes batches primarily by age (i.e. the batch's oldest request)
58 // along with a configurable preference for scheduling larger batches first.
59 //
60 //
61 // ASBS tries to keep the system busy by maintaining an adjustable number of
62 // concurrently processed batches.  If a new batch is created, and the number of
63 // in flight batches is below the target, the next (i.e. oldest) batch is
64 // immediately scheduled.  Similarly, when a batch finishes processing, the
65 // target is rechecked, and another batch may be scheduled.  To avoid the need
66 // to carefully tune the target for workload, model type, platform, etc, it is
67 // dynamically adjusted in order to provide the lowest average latency.
68 //
69 // Some potential use cases:
70 // Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
71 //   involves serial processing by a device, from a latency perspective it is
72 //   desirable to keep the device evenly loaded, avoiding the need to wait for
73 //   the device to process prior batches.
// CPU utilization - If batch processing is CPU-dominated, you can reap
//   latency gains while the system is underutilized by increasing the
//   processing rate, and back the rate off when load increases to avoid
//   overload.
77 
template <typename TaskType>
class AdaptiveSharedBatchScheduler
    : public std::enable_shared_from_this<
          AdaptiveSharedBatchScheduler<TaskType>> {
 public:
  ~AdaptiveSharedBatchScheduler() {
    // Finish processing batches before destroying other class members.
    // The explicit reset matters: members declared after batch_thread_pool_
    // (in_flight_batches_, batch_count_, ...) are used by CallbackWrapper and
    // would otherwise be destroyed before the pool. NOTE(review): when an
    // external Options::thread_pool is wrapped, confirm the wrapper's
    // destructor also waits for scheduled work.
    batch_thread_pool_.reset();
  }

  struct Options {
    // The name to use for the pool of batch threads.
    string thread_pool_name = {"batch_threads"};
    // Number of batch processing threads - the maximum value of
    // in_flight_batches_limit_.  It is recommended that this value be set by
    // running the system under load, observing the learned value for
    // in_flight_batches_limit_, and setting this maximum to ~ 2x the value.
    // Under low load, in_flight_batches_limit_ has no substantial effect on
    // latency and therefore undergoes a random walk.  Unreasonably large values
    // for num_batch_threads allow for large in_flight_batches_limit_, which
    // will harm latency for some time once load increases again.
    int64 num_batch_threads = port::MaxParallelism();
    // You can pass a ThreadPoolInterface directly instead of having the
    // scheduler create its own pool.  If given, thread_pool_name is ignored,
    // but num_batch_threads is still validated by Create(), still caps
    // in_flight_batches_limit_, and is still used to count idle threads
    // available for express (closed) batches - so presumably it should be set
    // to the given pool's thread count.  Ownership of the threadpool is not
    // transferred.
    thread::ThreadPoolInterface* thread_pool = nullptr;
    // Lower bound for in_flight_batches_limit_. As discussed above, can be used
    // to minimize the damage caused by the random walk under low load.
    int64 min_in_flight_batches_limit = 1;
    // Although batch selection is primarily based on age, this parameter
    // specifies a preference for larger batches.  A full batch will be
    // scheduled before an older, nearly empty batch as long as the age gap is
    // less than full_batch_scheduling_boost_micros.  The optimal value for this
    // parameter should be of order the batch processing latency, but must be
    // chosen carefully, as too large a value will harm tail latency.
    int64 full_batch_scheduling_boost_micros = 0;
    // The environment to use (typically only overridden by test code).
    Env* env = Env::Default();
    // Initial limit for number of batches being concurrently processed.
    // Non-integer values correspond to probabilistic limits - i.e. a value of
    // 3.2 results in an actual cap of 3 80% of the time, and 4 20% of the time.
    double initial_in_flight_batches_limit = 3;
    // Number of batches between adjustments of in_flight_batches_limit.  Larger
    // numbers will give less noisy latency measurements, but will be less
    // responsive to changes in workload.
    int64 batches_to_average_over = 1000;
  };

  // Ownership is shared between the caller of Create() and any queues created
  // via AddQueue().
  static Status Create(
      const Options& options,
      std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);

  struct QueueOptions {
    // Maximum size of a batch that's formed within
    // `ASBSQueue<TaskType>::Schedule`.
    int max_batch_size = 1000;
    // Maximum size of input task, which is submitted to the queue by
    // calling `ASBSQueue<TaskType>::Schedule` and used to form batches.
    //
    // If specified, it should be larger than or equal to 'max_batch_size'.
    absl::optional<int> max_input_task_size = absl::nullopt;
    // Maximum number of enqueued (i.e. non-scheduled) batches.
    int max_enqueued_batches = 10;
    // Amount of time non-full batches must wait before becoming schedulable.
    // A non-zero value can improve performance by limiting the scheduling of
    // nearly empty batches.
    int64 batch_timeout_micros = 0;
    // If non nullptr, split_input_task_func should split input_task into
    // multiple tasks, the first of which has size first_size and the remaining
    // not exceeding max_size. This function may acquire ownership of input_task
    // and should return a status indicating if the split was successful. Upon
    // success, the caller can assume that all output_tasks will be scheduled.
    // Including this option allows the scheduler to pack batches better and
    // should usually improve overall throughput.
    std::function<Status(std::unique_ptr<TaskType>* input_task, int first_size,
                         int max_batch_size,
                         std::vector<std::unique_ptr<TaskType>>* output_tasks)>
        split_input_task_func;
  };

  using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;

  // Adds queue (and its callback) to be managed by this scheduler.
  Status AddQueue(const QueueOptions& options,
                  BatchProcessor process_batch_callback,
                  std::unique_ptr<BatchScheduler<TaskType>>* queue);

  // Returns the current (possibly fractional) in-flight batch limit. Takes mu_.
  double in_flight_batches_limit() {
    mutex_lock l(mu_);
    return in_flight_batches_limit_;
  }

 private:
  // access to AddBatch, MaybeScheduleClosedBatches, RemoveQueue, GetEnv.
  friend class internal::ASBSQueue<TaskType>;

  explicit AdaptiveSharedBatchScheduler(const Options& options);

  // Runs the batch callback, then (for non-express batches) tracks latency and
  // adjusts in_flight_batches_limit to minimize it.
  void CallbackWrapper(const internal::ASBSBatch<TaskType>* batch,
                       BatchProcessor callback, bool is_express);

  // Schedules batch if in_flight_batches_limit_ is not met.
  void MaybeScheduleNextBatch() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Schedules all closed batches in batches_ for which an idle thread is
  // available in batch_thread_pool_.
  // Batches scheduled this way are called express batches.
  // Express batches are not limited by in_flight_batches_limit_, and
  // their latencies will not affect in_flight_batches_limit_.
  void MaybeScheduleClosedBatches();

  void MaybeScheduleClosedBatchesLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Notifies scheduler of non-empty batch which is eligible for processing.
  void AddBatch(const internal::ASBSBatch<TaskType>* batch);

  // Removes queue from scheduler.
  void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);

  Env* GetEnv() const { return options_.env; }

  const Options options_;

  // Collection of batches added by AddBatch, ordered by age. Owned by scheduler
  // until they are released for processing.
  std::vector<const internal::ASBSBatch<TaskType>*> batches_ TF_GUARDED_BY(mu_);

  // Unowned queues and callbacks added by AddQueue.
  std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
      queues_and_callbacks_ TF_GUARDED_BY(mu_);

  mutex mu_;

  // Responsible for running the batch processing callbacks.
  std::unique_ptr<thread::ThreadPool> batch_thread_pool_;

  // Limit on number of batches which can be concurrently processed.
  // Non-integer values correspond to probabilistic limits - i.e. a value of 3.2
  // results in an actual cap of 3 80% of the time, and 4 20% of the time.
  double in_flight_batches_limit_ TF_GUARDED_BY(mu_);

  // Number of regular batches currently being processed.
  int64 in_flight_batches_ TF_GUARDED_BY(mu_) = 0;
  // Number of express batches currently being processed.
  int64 in_flight_express_batches_ TF_GUARDED_BY(mu_) = 0;

  // RNG engine and distribution. Seeded once in the constructor; afterwards
  // only used by MaybeScheduleNextBatch(), which requires mu_.
  std::default_random_engine rand_engine_;
  std::uniform_real_distribution<double> rand_double_;

  // Fields controlling the dynamic adjustment of in_flight_batches_limit_.
  // Number of batches since the last in_flight_batches_limit_ adjustment.
  int64 batch_count_ TF_GUARDED_BY(mu_) = 0;
  // Sum of latencies, in microseconds, for batches counted by batch_count_.
  // Each latency runs from the batch's creation time until its callback
  // returns, so it includes time spent waiting to be scheduled.
  int64 batch_latency_sum_ TF_GUARDED_BY(mu_) = 0;
  // Average batch latency for previous value of in_flight_batches_limit_.
  double last_avg_latency_ms_ TF_GUARDED_BY(mu_) = 0;
  // Did last_avg_latency_ms_ decrease from the previous last_avg_latency_ms_?
  bool last_latency_decreased_ TF_GUARDED_BY(mu_) = false;
  // Current direction (+-) to adjust in_flight_batches_limit_
  int step_direction_ TF_GUARDED_BY(mu_) = 1;
  // Max adjustment size (as a fraction of in_flight_batches_limit_).
  constexpr static double kMaxStepSizeMultiplier = 0.125;  // 1/8;
  // Min adjustment size (as a fraction of in_flight_batches_limit_).
  constexpr static double kMinStepSizeMultiplier = 0.0078125;  // 1/128
  // Current adjustment size (as a fraction of in_flight_batches_limit_).
  double step_size_multiplier_ TF_GUARDED_BY(mu_) = kMaxStepSizeMultiplier;

  TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
};
251 
252 //////////////////////////////////////////////////////////
253 // Implementation details follow. API users need not read.
254 
255 namespace internal {
// Consolidates tasks into batches, passing them off to the
// AdaptiveSharedBatchScheduler for processing.
template <typename TaskType>
class ASBSQueue : public BatchScheduler<TaskType> {
 public:
  using QueueOptions =
      typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;

  ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
            const QueueOptions& options);

  // Blocks until every batch created by this queue has been released to the
  // scheduler, then unregisters the queue.
  ~ASBSQueue() override;

  // Adds task to the current batch, creating new batches as needed. Returns
  // InvalidArgument if the task is larger than max_batch_size (only checked
  // when split_input_task_func is not set) or larger than max_input_task_size
  // (when set), and Unavailable if the task does not fit in the remaining
  // capacity of the current batch plus the spare enqueued-batch slots.
  Status Schedule(std::unique_ptr<TaskType>* task) override;

  // Number of tasks waiting to be scheduled.
  size_t NumEnqueuedTasks() const override;

  // Number of size 1 tasks which could currently be scheduled without failing.
  size_t SchedulingCapacity() const override;

  // Notifies queue that a batch is about to be scheduled; the queue should not
  // place any more tasks in this batch.
  void ReleaseBatch(const ASBSBatch<TaskType>* batch);

  // Reports max_batch_size (not max_input_task_size) as the max task size.
  size_t max_task_size() const override { return options_.max_batch_size; }

 private:
  // Number of size 1 tasks which could currently be scheduled without failing.
  size_t SchedulingCapacityLocked() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns uint64 one greater than was returned by the previous call.
  // Context id is reused after std::numeric_limits<uint64>::max is exhausted.
  static uint64 NewTraceMeContextIdForBatch();

  std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
  const QueueOptions options_;
  // Batch currently being filled, or nullptr. Owned by scheduler_.
  ASBSBatch<TaskType>* current_batch_ TF_GUARDED_BY(mu_) = nullptr;
  // Batches created by this queue that have not yet been released
  // (ReleaseBatch), and the number of tasks they hold.
  int64 num_enqueued_batches_ TF_GUARDED_BY(mu_) = 0;
  int64 num_enqueued_tasks_ TF_GUARDED_BY(mu_) = 0;
  mutable mutex mu_;
  TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
};
303 
304 // Batch which remembers when and by whom it was created.
305 template <typename TaskType>
306 class ASBSBatch : public Batch<TaskType> {
307  public:
ASBSBatch(ASBSQueue<TaskType> * queue,int64 creation_time_micros,int64 batch_timeout_micros,uint64 traceme_context_id)308   ASBSBatch(ASBSQueue<TaskType>* queue, int64 creation_time_micros,
309             int64 batch_timeout_micros, uint64 traceme_context_id)
310       : queue_(queue),
311         creation_time_micros_(creation_time_micros),
312         schedulable_time_micros_(creation_time_micros + batch_timeout_micros),
313         traceme_context_id_(traceme_context_id) {}
314 
~ASBSBatch()315   ~ASBSBatch() override {}
316 
queue()317   ASBSQueue<TaskType>* queue() const { return queue_; }
318 
creation_time_micros()319   int64 creation_time_micros() const { return creation_time_micros_; }
320 
schedulable_time_micros()321   int64 schedulable_time_micros() const { return schedulable_time_micros_; }
322 
traceme_context_id()323   uint64 traceme_context_id() const { return traceme_context_id_; }
324 
325  private:
326   ASBSQueue<TaskType>* queue_;
327   const int64 creation_time_micros_;
328   const int64 schedulable_time_micros_;
329   const uint64 traceme_context_id_;
330   TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
331 };
332 }  // namespace internal
333 
334 // ---------------- AdaptiveSharedBatchScheduler ----------------
335 
// Out-of-line definitions of the static constexpr step-size constants. Before
// C++17 these are required because the constants are ODR-used (CallbackWrapper
// passes them to std::min/std::max, which bind const references). Since C++17
// static constexpr data members are implicitly inline, so these redundant
// (deprecated but harmless) redeclarations only matter for older standards.
template <typename TaskType>
constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMaxStepSizeMultiplier;

template <typename TaskType>
constexpr double AdaptiveSharedBatchScheduler<TaskType>::kMinStepSizeMultiplier;
341 
342 template <typename TaskType>
Create(const Options & options,std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> * scheduler)343 Status AdaptiveSharedBatchScheduler<TaskType>::Create(
344     const Options& options,
345     std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
346   if (options.num_batch_threads < 1) {
347     return errors::InvalidArgument("num_batch_threads must be positive; was ",
348                                    options.num_batch_threads);
349   }
350   if (options.min_in_flight_batches_limit < 1) {
351     return errors::InvalidArgument(
352         "min_in_flight_batches_limit must be >= 1; was ",
353         options.min_in_flight_batches_limit);
354   }
355   if (options.min_in_flight_batches_limit > options.num_batch_threads) {
356     return errors::InvalidArgument(
357         "min_in_flight_batches_limit (", options.min_in_flight_batches_limit,
358         ") must be <= num_batch_threads (", options.num_batch_threads, ")");
359   }
360   if (options.full_batch_scheduling_boost_micros < 0) {
361     return errors::InvalidArgument(
362         "full_batch_scheduling_boost_micros can't be negative; was ",
363         options.full_batch_scheduling_boost_micros);
364   }
365   if (options.initial_in_flight_batches_limit > options.num_batch_threads) {
366     return errors::InvalidArgument(
367         "initial_in_flight_batches_limit (",
368         options.initial_in_flight_batches_limit,
369         ") should not be larger than num_batch_threads (",
370         options.num_batch_threads, ")");
371   }
372   if (options.initial_in_flight_batches_limit <
373       options.min_in_flight_batches_limit) {
374     return errors::InvalidArgument("initial_in_flight_batches_limit (",
375                                    options.initial_in_flight_batches_limit,
376                                    "must be >= min_in_flight_batches_limit (",
377                                    options.min_in_flight_batches_limit, ")");
378   }
379   if (options.batches_to_average_over < 1) {
380     return errors::InvalidArgument(
381         "batches_to_average_over should be "
382         "greater than or equal to 1; was ",
383         options.batches_to_average_over);
384   }
385   scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
386   return Status::OK();
387 }
388 
389 template <typename TaskType>
AdaptiveSharedBatchScheduler(const Options & options)390 AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
391     const Options& options)
392     : options_(options),
393       in_flight_batches_limit_(options.initial_in_flight_batches_limit),
394       rand_double_(0.0, 1.0) {
395   std::random_device device;
396   rand_engine_.seed(device());
397   if (options.thread_pool == nullptr) {
398     batch_thread_pool_.reset(new thread::ThreadPool(
399         GetEnv(), options.thread_pool_name, options.num_batch_threads));
400   } else {
401     batch_thread_pool_.reset(new thread::ThreadPool(options.thread_pool));
402   }
403 }
404 
405 template <typename TaskType>
AddQueue(const QueueOptions & options,BatchProcessor process_batch_callback,std::unique_ptr<BatchScheduler<TaskType>> * queue)406 Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
407     const QueueOptions& options, BatchProcessor process_batch_callback,
408     std::unique_ptr<BatchScheduler<TaskType>>* queue) {
409   if (options.max_batch_size <= 0) {
410     return errors::InvalidArgument("max_batch_size must be positive; was ",
411                                    options.max_batch_size);
412   }
413   if (options.max_enqueued_batches <= 0) {
414     return errors::InvalidArgument(
415         "max_enqueued_batches must be positive; was ",
416         options.max_enqueued_batches);
417   }
418   if (options.max_input_task_size.has_value()) {
419     if (options.max_input_task_size.value() < options.max_batch_size) {
420       return errors::InvalidArgument(
421           "max_input_task_size must be larger than or equal to max_batch_size;"
422           "got max_input_task_size as ",
423           options.max_input_task_size.value(), " and max_batch_size as ",
424           options.max_batch_size);
425     }
426   }
427   internal::ASBSQueue<TaskType>* asbs_queue_raw;
428   queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
429                    this->shared_from_this(), options));
430   mutex_lock l(mu_);
431   queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
432   return Status::OK();
433 }
434 
435 template <typename TaskType>
AddBatch(const internal::ASBSBatch<TaskType> * batch)436 void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
437     const internal::ASBSBatch<TaskType>* batch) {
438   mutex_lock l(mu_);
439   batches_.push_back(batch);
440   int64 delay_micros = batch->schedulable_time_micros() - GetEnv()->NowMicros();
441   if (delay_micros <= 0) {
442     MaybeScheduleNextBatch();
443     return;
444   }
445   // Try to schedule batch once it becomes schedulable. Although scheduler waits
446   // for all batches to finish processing before allowing itself to be deleted,
447   // MaybeScheduleNextBatch() is called in other places, and therefore it's
448   // possible the scheduler could be deleted by the time this closure runs.
449   // Grab a shared_ptr reference to prevent this from happening.
450   GetEnv()->SchedClosureAfter(
451       delay_micros, [this, lifetime_preserver = this->shared_from_this()] {
452         mutex_lock l(mu_);
453         MaybeScheduleNextBatch();
454       });
455 }
456 
457 template <typename TaskType>
RemoveQueue(const internal::ASBSQueue<TaskType> * queue)458 void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
459     const internal::ASBSQueue<TaskType>* queue) {
460   mutex_lock l(mu_);
461   queues_and_callbacks_.erase(queue);
462 }
463 
// Picks at most one schedulable batch from batches_ and dispatches it to the
// thread pool as a regular (non-express) batch, if the in-flight limit allows.
//
// Selection: among batches whose schedulable_time_micros() has passed, choose
// the lowest score = creation_time_micros - boost * size / max_task_size.
// Older batches and fuller batches therefore win; a completely full batch gets
// the whole full_batch_scheduling_boost_micros head start. Ties go to the
// earliest entry in batches_ (strict `<`).
template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleNextBatch() {
  if (batches_.empty() || in_flight_batches_ >= in_flight_batches_limit_)
    return;
  // Non-integer limit handled probabilistically.
  // With a remaining allowance r = limit - in_flight in (0, 1), proceed with
  // probability r (bail out when the uniform draw exceeds r).
  if (in_flight_batches_limit_ - in_flight_batches_ < 1 &&
      rand_double_(rand_engine_) >
          in_flight_batches_limit_ - in_flight_batches_) {
    return;
  }
  auto best_it = batches_.end();
  double best_score = (std::numeric_limits<double>::max)();
  int64 now_micros = GetEnv()->NowMicros();
  for (auto it = batches_.begin(); it != batches_.end(); it++) {
    // Skip batches still within their batch_timeout_micros window.
    if ((*it)->schedulable_time_micros() > now_micros) continue;
    const double score =
        (*it)->creation_time_micros() -
        options_.full_batch_scheduling_boost_micros * (*it)->size() /
            static_cast<double>((*it)->queue()->max_task_size());
    if (best_it == batches_.end() || score < best_score) {
      best_score = score;
      best_it = it;
    }
  }
  // No schedulable batches.
  if (best_it == batches_.end()) return;
  const internal::ASBSBatch<TaskType>* batch = *best_it;
  batches_.erase(best_it);
  // Queue may destroy itself after ReleaseBatch is called.
  // batch->queue() below is only used as a map key (not dereferenced). The map
  // entry is still present because the queue's destructor removes it via
  // RemoveQueue(), which needs mu_, held here.
  batch->queue()->ReleaseBatch(batch);
  batch_thread_pool_->Schedule(
      std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, this,
                batch, queues_and_callbacks_[batch->queue()], false));
  in_flight_batches_++;
}
499 
500 template <typename TaskType>
MaybeScheduleClosedBatches()501 void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleClosedBatches() {
502   mutex_lock l(mu_);
503   MaybeScheduleClosedBatchesLocked();
504 }
505 
506 template <typename TaskType>
507 void AdaptiveSharedBatchScheduler<
MaybeScheduleClosedBatchesLocked()508     TaskType>::MaybeScheduleClosedBatchesLocked() {
509   // Only schedule closed batches if we have spare capacity.
510   int available_threads =
511       static_cast<int>(options_.num_batch_threads - in_flight_batches_ -
512                        in_flight_express_batches_);
513   for (auto it = batches_.begin();
514        it != batches_.end() && available_threads > 0;) {
515     if ((*it)->IsClosed()) {
516       const internal::ASBSBatch<TaskType>* batch = *it;
517       it = batches_.erase(it);
518       batch->queue()->ReleaseBatch(batch);
519       batch_thread_pool_->Schedule(
520           std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper,
521                     this, batch, queues_and_callbacks_[batch->queue()], true));
522       in_flight_express_batches_++;
523       available_threads--;
524     } else {
525       ++it;
526     }
527   }
528 }
529 
// Runs `callback` on `batch` (invoked via batch_thread_pool_), passing batch
// ownership to the callback, then updates the in-flight accounting under mu_.
// - Express batches: free their slot and try to dispatch more closed batches.
//   Their latency is not recorded.
// - Regular batches: record latency, which runs from
//   batch->creation_time_micros() (so it includes queueing time) until the
//   callback returns. Every options_.batches_to_average_over batches, move
//   in_flight_batches_limit_ by a multiplicative step, clamp it to
//   [min_in_flight_batches_limit, num_batch_threads], and then try to
//   schedule the next batch.
template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper(
    const internal::ASBSBatch<TaskType>* batch,
    AdaptiveSharedBatchScheduler<TaskType>::BatchProcessor callback,
    bool is_express) {
  // NOTE(review): the name generator dereferences `batch`; this is only safe if
  // TraceMeConsumer evaluates it before the callback below takes ownership of
  // (and may delete) the batch - presumably at construction; confirm.
  profiler::TraceMeConsumer trace_me(
      [&] {
        return profiler::TraceMeEncode(
            "ProcessBatch", {{"batch_size_before_padding", batch->size()}});
      },
      profiler::ContextType::kAdaptiveSharedBatchScheduler,
      batch->traceme_context_id());
  // Read before invoking the callback: afterwards `batch` is owned by the
  // unique_ptr handed to the callback and must not be touched.
  int64 start_time = batch->creation_time_micros();
  callback(std::unique_ptr<Batch<TaskType>>(
      const_cast<internal::ASBSBatch<TaskType>*>(batch)));
  int64 end_time = GetEnv()->NowMicros();
  mutex_lock l(mu_);
  if (is_express) {
    in_flight_express_batches_--;
    MaybeScheduleClosedBatchesLocked();
    return;
  }
  in_flight_batches_--;
  batch_count_++;
  batch_latency_sum_ += end_time - start_time;
  // Occasionally adjust in_flight_batches_limit_ to minimize average latency.
  // Although the optimal value may depend on the workload, the latency should
  // be a simple convex function of in_flight_batches_limit_, allowing us to
  // locate the global minimum relatively quickly.
  if (batch_count_ == options_.batches_to_average_over) {
    // Microseconds -> milliseconds, averaged over the window.
    double current_avg_latency_ms = (batch_latency_sum_ / 1000.) / batch_count_;
    // On the first window last_avg_latency_ms_ is still 0, so this is false
    // and the step direction flips.
    bool current_latency_decreased =
        current_avg_latency_ms < last_avg_latency_ms_;
    if (current_latency_decreased) {
      // If latency improvement was because we're moving in the correct
      // direction, increase step_size so that we can get to the minimum faster.
      // If latency improvement was due to backtracking from a previous failure,
      // decrease step_size in order to refine our location.
      step_size_multiplier_ *= (last_latency_decreased_ ? 2 : 0.5);
      step_size_multiplier_ =
          std::min(step_size_multiplier_, kMaxStepSizeMultiplier);
      step_size_multiplier_ =
          std::max(step_size_multiplier_, kMinStepSizeMultiplier);
    } else {
      // Return (nearly) to previous position and confirm that latency is better
      // there before decreasing step size.
      step_direction_ = -step_direction_;
    }
    in_flight_batches_limit_ +=
        step_direction_ * in_flight_batches_limit_ * step_size_multiplier_;
    in_flight_batches_limit_ =
        std::min(in_flight_batches_limit_,
                 static_cast<double>(options_.num_batch_threads));
    in_flight_batches_limit_ =
        std::max(in_flight_batches_limit_,
                 static_cast<double>(options_.min_in_flight_batches_limit));
    last_avg_latency_ms_ = current_avg_latency_ms;
    last_latency_decreased_ = current_latency_decreased;
    batch_count_ = 0;
    batch_latency_sum_ = 0;
  }
  MaybeScheduleNextBatch();
}
593 
594 // ---------------- ASBSQueue ----------------
595 
596 namespace internal {
597 template <typename TaskType>
ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,const QueueOptions & options)598 ASBSQueue<TaskType>::ASBSQueue(
599     std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
600     const QueueOptions& options)
601     : scheduler_(scheduler), options_(options) {}
602 
603 template <typename TaskType>
~ASBSQueue()604 ASBSQueue<TaskType>::~ASBSQueue() {
605   // Wait until last batch has been scheduled.
606   const int kSleepMicros = 1000;
607   for (;;) {
608     {
609       mutex_lock l(mu_);
610       if (num_enqueued_batches_ == 0) {
611         break;
612       }
613     }
614     scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
615   }
616   scheduler_->RemoveQueue(this);
617 }
618 
// Adds `*task` to this queue's batches. See the declaration for the error
// cases. Under mu_ it:
//   1. rejects the task if it exceeds SchedulingCapacityLocked();
//   2. optionally splits it so the first piece exactly fills the current batch;
//   3. appends each piece, closing a batch when the next piece would overflow
//      it or when it becomes exactly full, and creating new batches as needed.
// After releasing mu_ it registers new batches with the scheduler and, if any
// batch was closed, asks the scheduler to dispatch closed batches.
template <typename TaskType>
Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
  size_t size = (*task)->size();
  if (options_.split_input_task_func == nullptr &&
      size > options_.max_batch_size) {
    return errors::InvalidArgument("Task size ", size,
                                   " is larger than maximum batch size ",
                                   options_.max_batch_size);
  }
  if (options_.max_input_task_size.has_value() &&
      (size > options_.max_input_task_size.value())) {
    return errors::InvalidArgument("Task size ", size,
                                   " is larger than max input task size ",
                                   options_.max_input_task_size.value());
  }

  std::vector<std::unique_ptr<TaskType>> tasks_to_schedule;
  std::vector<ASBSBatch<TaskType>*> new_batches;
  bool closed_batch = false;
  {
    mutex_lock l(mu_);
    if (size > SchedulingCapacityLocked()) {
      return errors::Unavailable("The batch scheduling queue is full");
    }

    int remaining_batch_size =
        current_batch_ == nullptr
            ? options_.max_batch_size
            : options_.max_batch_size - current_batch_->size();
    if (options_.split_input_task_func == nullptr ||
        size <= remaining_batch_size) {
      // Either we don't allow task splitting or task fits within the current
      // batch.
      tasks_to_schedule.push_back(std::move(*task));
    } else {
      // Split task in order to completely fill the current batch.
      // Beyond this point Schedule should not fail, as the caller has been
      // promised that all of the split tasks will be scheduled.
      TF_RETURN_IF_ERROR(options_.split_input_task_func(
          task, remaining_batch_size, options_.max_batch_size,
          &tasks_to_schedule));
    }
    // Note: this loop variable shadows the `task` parameter.
    for (auto& task : tasks_to_schedule) {
      // Can't fit within current batch, close it off and try to create another.
      if (current_batch_ &&
          current_batch_->size() + task->size() > options_.max_batch_size) {
        current_batch_->Close();
        closed_batch = true;
        current_batch_ = nullptr;
      }
      if (!current_batch_) {
        num_enqueued_batches_++;
        // batch.traceme_context_id connects TraceMeProducer and
        // TraceMeConsumer.
        // When multiple calls to "ASBS::Schedule" accumulate to one batch, they
        // are processed in the same batch and should share traceme_context_id.
        current_batch_ = new ASBSBatch<TaskType>(
            this, scheduler_->GetEnv()->NowMicros(),
            options_.batch_timeout_micros, NewTraceMeContextIdForBatch());
        new_batches.push_back(current_batch_);
      }

      // Annotate each task (corresponds to one call of schedule) with a
      // TraceMeProducer.
      profiler::TraceMeProducer trace_me(
          [task_size = task->size()] {
            return profiler::TraceMeEncode(
                "ASBSQueue::Schedule",
                {{"batching_input_task_size", task_size}});
          },
          profiler::ContextType::kAdaptiveSharedBatchScheduler,
          this->current_batch_->traceme_context_id());
      current_batch_->AddTask(std::move(task));
      num_enqueued_tasks_++;
      // If current_batch_ is now full, allow it to be processed immediately.
      if (current_batch_->size() == options_.max_batch_size) {
        current_batch_->Close();
        closed_batch = true;
        current_batch_ = nullptr;
      }
    }
  }
  // Scheduler functions must be called outside of lock, since they may call
  // ReleaseBatch.
  // (ReleaseBatch takes this queue's mu_, so holding it here would deadlock.)
  for (auto* batch : new_batches) {
    scheduler_->AddBatch(batch);
  }
  if (closed_batch) {
    scheduler_->MaybeScheduleClosedBatches();
  }
  return Status::OK();
}
711 
712 template <typename TaskType>
ReleaseBatch(const ASBSBatch<TaskType> * batch)713 void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
714   mutex_lock l(mu_);
715   num_enqueued_batches_--;
716   num_enqueued_tasks_ -= batch->num_tasks();
717   if (batch == current_batch_) {
718     current_batch_->Close();
719     current_batch_ = nullptr;
720   }
721 }
722 
723 template <typename TaskType>
NumEnqueuedTasks()724 size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
725   mutex_lock l(mu_);
726   return num_enqueued_tasks_;
727 }
728 
729 template <typename TaskType>
SchedulingCapacity()730 size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
731   mutex_lock l(mu_);
732   return SchedulingCapacityLocked();
733 }
734 
735 template <typename TaskType>
SchedulingCapacityLocked()736 size_t ASBSQueue<TaskType>::SchedulingCapacityLocked() const {
737   const int current_batch_capacity =
738       current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
739   const int spare_batches =
740       options_.max_enqueued_batches - num_enqueued_batches_;
741   return spare_batches * options_.max_batch_size + current_batch_capacity;
742 }
743 
744 template <typename TaskType>
745 // static
NewTraceMeContextIdForBatch()746 uint64 ASBSQueue<TaskType>::NewTraceMeContextIdForBatch() {
747   static std::atomic<uint64> traceme_context_id(0);
748   return traceme_context_id.fetch_add(1, std::memory_order_relaxed);
749 }
750 }  // namespace internal
751 }  // namespace serving
752 }  // namespace tensorflow
753 
754 #endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
755