1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
17 
18 #include "absl/types/optional.h"
19 #include "tensorflow/core/framework/ops_util.h"
20 #include "tensorflow/core/framework/tensor_util.h"
21 #include "tensorflow/core/kernels/batching_util/concat_split_util.h"
22 #include "tensorflow/core/lib/gtl/cleanup.h"
23 #include "tensorflow/core/lib/monitoring/gauge.h"
24 #include "tensorflow/core/lib/monitoring/percentile_sampler.h"
25 #include "tensorflow/core/profiler/lib/traceme.h"
26 #include "tensorflow/core/profiler/lib/traceme_encode.h"
27 #include "tensorflow/core/util/incremental_barrier.h"
28 
29 namespace tensorflow {
30 namespace serving {
31 namespace {
32 
RecordPaddingSize(int32 padding_size,const string & model_name,int32 execution_batch_size,const string & op_name)33 void RecordPaddingSize(int32 padding_size, const string& model_name,
34                        int32 execution_batch_size, const string& op_name) {
35   static auto* cell = tensorflow::monitoring::PercentileSampler<3>::New(
36       {"/tensorflow/serving/batching/padding_size",
37        "Tracks the padding size distribution on batches by model_name (if "
38        "available).",
39        "model_name", "execution_batch_size", "op_name"},
40       /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
41       /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
42   cell->GetCell(model_name, absl::StrCat(execution_batch_size), op_name)
43       ->Add(static_cast<double>(padding_size));
44 }
45 
RecordInputBatchSize(int32 batch_size,const string & model_name,const string & op_name)46 void RecordInputBatchSize(int32 batch_size, const string& model_name,
47                           const string& op_name) {
48   static auto* cell = tensorflow::monitoring::PercentileSampler<2>::New(
49       {"/tensorflow/serving/batching/input_batch_size",
50        "Tracks the batch size distribution on the inputs by model_name (if "
51        "available).",
52        "model_name", "op_name"},
53       /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
54       /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
55   cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
56 }
57 
RecordProcessedBatchSize(int32 batch_size,const string & model_name,const string & op_name)58 void RecordProcessedBatchSize(int32 batch_size, const string& model_name,
59                               const string& op_name) {
60   static auto* cell = tensorflow::monitoring::PercentileSampler<2>::New(
61       {"/tensorflow/serving/batching/processed_batch_size",
62        "Tracks the batch size distribution on processing by model_name (if "
63        "available).",
64        "model_name", "op_name"},
65       /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
66       /*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber);
67   cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_size));
68 }
69 
RecordBatchDelayUs(int64 batch_delay_us,const string & model_name,const string & op_name)70 void RecordBatchDelayUs(int64 batch_delay_us, const string& model_name,
71                         const string& op_name) {
72   static auto* cell = monitoring::PercentileSampler<2>::New(
73       {"/tensorflow/serving/batching/batch_delay_us",
74        "Tracks the batching delay (in microseconds) for inputs by model_name "
75        "(if available).",
76        "model_name", "op_name"},
77       /*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
78       /*max_samples=*/1024, monitoring::UnitOfMeasure::kTime);
79   cell->GetCell(model_name, op_name)->Add(static_cast<double>(batch_delay_us));
80 }
81 
RecordBatchParamBatchTimeoutMicros(int64 batch_timeout_micros,const string & model_name,const string & op_name)82 void RecordBatchParamBatchTimeoutMicros(int64 batch_timeout_micros,
83                                         const string& model_name,
84                                         const string& op_name) {
85   static auto* cell = monitoring::Gauge<int64, 2>::New(
86       "/tensorflow/serving/batching/batch_timeout_micros",
87       "Tracks how long a request can wait before being processed by a batch.",
88       "model_name", "op_name");
89   cell->GetCell(model_name, op_name)->Set(batch_timeout_micros);
90 }
91 
RecordBatchParamMaxBatchSize(int64 max_batch_size,const string & model_name,const string & op_name)92 void RecordBatchParamMaxBatchSize(int64 max_batch_size,
93                                   const string& model_name,
94                                   const string& op_name) {
95   static auto* cell = monitoring::Gauge<int64, 2>::New(
96       "/tensorflow/serving/batching/max_batch_size",
97       "Tracks the maximum size of a batch.", "model_name", "op_name");
98   cell->GetCell(model_name, op_name)->Set(max_batch_size);
99 }
100 
RecordBatchParamMaxEnqueuedBatches(int64 max_enqueued_batches,const string & model_name,const string & op_name)101 void RecordBatchParamMaxEnqueuedBatches(int64 max_enqueued_batches,
102                                         const string& model_name,
103                                         const string& op_name) {
104   static auto* cell = monitoring::Gauge<int64, 2>::New(
105       "/tensorflow/serving/batching/max_enqueued_batches",
106       "Tracks the maximum number of enqueued batches.", "model_name",
107       "op_name");
108   cell->GetCell(model_name, op_name)->Set(max_enqueued_batches);
109 }
110 
RecordBatchParamAllowedBatchSizes(const string & allowed_batch_sizes,const string & model_name,const string & op_name)111 void RecordBatchParamAllowedBatchSizes(const string& allowed_batch_sizes,
112                                        const string& model_name,
113                                        const string& op_name) {
114   static auto* cell = monitoring::Gauge<string, 2>::New(
115       "/tensorflow/serving/batching/allowed_batch_sizes",
116       "Tracks the sizes that are allowed to form a batch.", "model_name",
117       "op_name");
118   cell->GetCell(model_name, op_name)->Set(allowed_batch_sizes);
119 }
120 
GetModelName(OpKernelContext * ctx)121 const string& GetModelName(OpKernelContext* ctx) {
122   static string* kModelNameUnset = new string("model_name_unset");
123   if (!ctx->session_metadata()) return *kModelNameUnset;
124   if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
125   return ctx->session_metadata()->name();
126 }
127 
128 }  // namespace
129 
130 std::unique_ptr<BatchResourceBase::BatchTask>
CreateSplitTask(int split_index,AsyncOpKernel::DoneCallback done_callback)131 BatchResourceBase::BatchTask::CreateSplitTask(
132     int split_index, AsyncOpKernel::DoneCallback done_callback) {
133   std::unique_ptr<BatchTask> task = CreateDerivedTask();
134 
135   task->guid = this->guid;
136   task->propagated_context = Context(ContextKind::kThread);
137   task->inputs.reserve(this->inputs.size());
138   task->captured_inputs = this->captured_inputs;
139   task->context = this->context;
140   task->done_callback = done_callback;
141   task->split_index = split_index;
142   task->output = this->output;
143   task->status = this->status;
144   task->is_partial = true;
145   task->start_time = this->start_time;
146 
147   return task;
148 }
149 
150 using ::tensorflow::concat_split_util::Concat;
151 using ::tensorflow::concat_split_util::Split;
152 using TensorMatrix = std::vector<std::vector<Tensor>>;
153 
RegisterInput(int64 guid,OpKernelContext * context,const string & batcher_queue_name,AsyncOpKernel::DoneCallback done_callback)154 Status BatchResourceBase::RegisterInput(
155     int64 guid, OpKernelContext* context, const string& batcher_queue_name,
156     AsyncOpKernel::DoneCallback done_callback) {
157   std::unique_ptr<BatchTask> batch_components;
158   TF_RETURN_IF_ERROR(CreateBatchTask(context, &batch_components));
159   batch_components->start_time = EnvTime::NowNanos();
160   batch_components->guid = guid;
161   batch_components->propagated_context = Context(ContextKind::kThread);
162   OpInputList tensors;
163   TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
164   batch_components->inputs.reserve(tensors.size());
165   for (const Tensor& tensor : tensors) {
166     if (tensor.shape().dims() == 0) {
167       return errors::InvalidArgument(
168           "Batching input tensors must have at least one dimension");
169     }
170     if (tensors.size() >= 2 &&
171         tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
172       return errors::InvalidArgument(
173           "Batching input tensors supplied in a given op invocation must "
174           "have equal 0th-dimension size");
175     }
176     batch_components->inputs.push_back(tensor);
177   }
178   RecordInputBatchSize(tensors[0].shape().dim_size(0), GetModelName(context),
179                        context->op_kernel().name_view().data());
180   RecordBatchParamBatchTimeoutMicros(
181       batcher_queue_options_.batch_timeout_micros, GetModelName(context),
182       context->op_kernel().name_view().data());
183   RecordBatchParamMaxBatchSize(batcher_queue_options_.max_execution_batch_size,
184                                GetModelName(context),
185                                context->op_kernel().name_view().data());
186   RecordBatchParamMaxEnqueuedBatches(
187       batcher_queue_options_.max_enqueued_batches, GetModelName(context),
188       context->op_kernel().name_view().data());
189   RecordBatchParamAllowedBatchSizes(allowed_batch_sizes_str_,
190                                     GetModelName(context),
191                                     context->op_kernel().name_view().data());
192 
193   // Degenerate case where the input is empty. Just return an empty tensor.
194   if (tensors[0].shape().dim_size(0) == 0) {
195     for (int i = 0; i < context->num_outputs(); i++) {
196       Tensor* empty_output;
197       AllocatorAttributes cpu_alloc;
198       cpu_alloc.set_on_host(true);
199       TF_RETURN_IF_ERROR(context->allocate_output(i, TensorShape({0}),
200                                                   &empty_output, cpu_alloc));
201     }
202     done_callback();
203     return Status::OK();
204   }
205   OpInputList captured_tensors;
206   const auto captured_status =
207       context->input_list("captured_tensors", &captured_tensors);
208   if (captured_status.ok()) {
209     batch_components->captured_inputs.reserve(captured_tensors.size());
210     for (const Tensor& captured_tensor : captured_tensors) {
211       batch_components->captured_inputs.push_back(captured_tensor);
212     }
213   }
214   batch_components->context = context;
215   batch_components->done_callback = std::move(done_callback);
216   batch_components->split_index = 0;
217   batch_components->output = std::make_shared<TensorMatrix>();
218   batch_components->status = std::make_shared<ThreadSafeStatus>();
219 
220   BatcherQueueT* batcher_queue;
221   TF_RETURN_IF_ERROR(
222       LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
223   return batcher_queue->Schedule(&batch_components);
224 }
225 
226 /*static*/ BatchResourceBase::BatcherT::QueueOptions
GetBatcherQueueOptions(int32 num_batch_threads,int32 max_batch_size,int32 batch_timeout_micros,int32 max_enqueued_batches,const std::vector<int32> & allowed_batch_sizes,bool enable_large_batch_splitting)227 BatchResourceBase::GetBatcherQueueOptions(
228     int32 num_batch_threads, int32 max_batch_size, int32 batch_timeout_micros,
229     int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes,
230     bool enable_large_batch_splitting) {
231   BatcherT::QueueOptions batcher_queue_options;
232   batcher_queue_options.input_batch_size_limit = max_batch_size;
233   batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
234   batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
235   batcher_queue_options.enable_large_batch_splitting =
236       enable_large_batch_splitting;
237   if (enable_large_batch_splitting) {
238     batcher_queue_options.split_input_task_func =
239         [](std::unique_ptr<BatchTask>* input_task,
240            int open_batch_remaining_slot, int max_batch_size,
241            std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
242       return SplitInputTask(input_task, open_batch_remaining_slot,
243                             max_batch_size, output_tasks);
244     };
245 
246     if (allowed_batch_sizes.empty()) {
247       batcher_queue_options.max_execution_batch_size = max_batch_size;
248     } else {
249       batcher_queue_options.max_execution_batch_size =
250           *allowed_batch_sizes.rbegin();
251     }
252   }
253 
254   return batcher_queue_options;
255 }
256 
257 /*static*/ BatchResourceBase::AdaptiveBatcherT::QueueOptions
GetAdaptiveBatcherQueueOptions(int32 max_batch_size,int32 batch_timeout_micros,int32 max_enqueued_batches,bool enable_large_batch_splitting,const std::vector<int32> & allowed_batch_sizes)258 BatchResourceBase::GetAdaptiveBatcherQueueOptions(
259     int32 max_batch_size, int32 batch_timeout_micros,
260     int32 max_enqueued_batches, bool enable_large_batch_splitting,
261     const std::vector<int32>& allowed_batch_sizes) {
262   AdaptiveBatcherT::QueueOptions batcher_queue_options;
263   batcher_queue_options.max_input_task_size =
264       absl::make_optional(max_batch_size);
265   batcher_queue_options.max_enqueued_batches = max_enqueued_batches;
266   batcher_queue_options.batch_timeout_micros = batch_timeout_micros;
267   if (allowed_batch_sizes.empty()) {
268     batcher_queue_options.max_batch_size = max_batch_size;
269   } else {
270     batcher_queue_options.max_batch_size = *allowed_batch_sizes.rbegin();
271   }
272 
273   if (enable_large_batch_splitting) {
274     batcher_queue_options.split_input_task_func =
275         [](std::unique_ptr<BatchTask>* input_task,
276            int open_batch_remaining_slot, int max_batch_size,
277            std::vector<std::unique_ptr<BatchTask>>* output_tasks) -> Status {
278       return SplitInputTask(input_task, open_batch_remaining_slot,
279                             max_batch_size, output_tasks);
280     };
281   }
282 
283   return batcher_queue_options;
284 }
285 
ValidateBatch(const BatchT & batch)286 /*static*/ Status BatchResourceBase::ValidateBatch(const BatchT& batch) {
287   for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
288     const BatchResourceBase::BatchTask& task = batch.task(task_idx);
289 
290     if (task.inputs.size() != batch.task(0).inputs.size()) {
291       return errors::InvalidArgument(
292           "Batching inputs must have equal number of edges");
293     }
294   }
295 
296   return Status::OK();
297 }
298 
299 // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
300 // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
301 // returns 'batch_size'.
RoundToLowestAllowedBatchSize(int batch_size) const302 int BatchResourceBase::RoundToLowestAllowedBatchSize(int batch_size) const {
303   if (allowed_batch_sizes_.empty()) {
304     return batch_size;
305   }
306   for (int allowed_size : allowed_batch_sizes_) {
307     if (allowed_size >= batch_size) {
308       return allowed_size;
309     }
310   }
311   LOG(ERROR) << "Batch size " << batch_size
312              << " is greater than largest allowed size; "
313                 "ignoring allowed sizes constraint.";
314   return batch_size;
315 }
316 
ConcatInputTensors(const BatchT & batch,OpKernelContext * context,std::vector<Tensor> * concatenated_tensors) const317 Status BatchResourceBase::ConcatInputTensors(
318     const BatchT& batch, OpKernelContext* context,
319     std::vector<Tensor>* concatenated_tensors) const {
320   if (batch.num_tasks() == 0) {
321     return errors::InvalidArgument("Empty batch.");
322   }
323 
324   const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
325   const int padding_amount = padded_batch_size - batch.size();
326   profiler::TraceMe trace_me([padded_batch_size, padding_amount]() {
327     return profiler::TraceMeEncode(
328         "ConcatInputTensors", {{"batch_size_after_padding", padded_batch_size},
329                                {"padding_amount", padding_amount}});
330   });
331   RecordPaddingSize(padding_amount, GetModelName(context), padded_batch_size,
332                     context->op_kernel().name_view().data());
333   RecordProcessedBatchSize(padded_batch_size, GetModelName(context),
334                            context->op_kernel().name_view().data());
335 
336   // All tasks should have the same number of input edges.
337   const int num_inputs = batch.task(0).inputs.size();
338   concatenated_tensors->reserve(num_inputs);
339 
340   // Process each input one at a time (the typical case has just one).
341   for (int i = 0; i < num_inputs; ++i) {
342     // Concatenate the tasks ith input tensors into a big output tensor.
343     std::vector<Tensor> to_concatenate;
344     to_concatenate.reserve(batch.num_tasks());
345     for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
346       to_concatenate.push_back(batch.task(task_idx).inputs.at(i));
347     }
348 
349     // Add padding as needed. Use the first row of the first task's tensor as
350     // the data for padding.
351     if (padding_amount > 0) {
352       const Tensor& padding_source = batch.task(0).inputs.at(i);
353       Tensor padding;
354       if (padding_source.shape().dim_size(0) == 0) {
355         return errors::InvalidArgument(
356             "Cannot use an empty tensor with zero rows as padding when "
357             "batching. (Input ",
358             i, " got shape ", padding_source.shape().DebugString(), ".)");
359       }
360       if (padding_source.shape().dim_size(0) == 1) {
361         padding = padding_source;
362       } else {
363         padding = padding_source.Slice(0, 1);
364       }
365       for (int i = 0; i < padding_amount; ++i) {
366         to_concatenate.push_back(padding);
367       }
368     }
369 
370     Tensor concatenated_tensor;
371     Status concat_status =
372         Concat(context, to_concatenate, &concatenated_tensor);
373     TF_RETURN_IF_ERROR(concat_status);
374     concatenated_tensors->push_back(concatenated_tensor);
375   }
376   return Status::OK();
377 }
378 
SplitInputTask(std::unique_ptr<BatchTask> * input_task_ptr,int open_batch_remaining_slot,int max_batch_size,std::vector<std::unique_ptr<BatchTask>> * output_tasks)379 /*static*/ Status BatchResourceBase::SplitInputTask(
380     std::unique_ptr<BatchTask>* input_task_ptr, int open_batch_remaining_slot,
381     int max_batch_size, std::vector<std::unique_ptr<BatchTask>>* output_tasks) {
382   BatchTask& input_task = *(*input_task_ptr);
383   const int64 input_task_size = input_task.size();
384 
385   DCHECK_GT(input_task_size, open_batch_remaining_slot);
386 
387   std::shared_ptr<ThreadSafeStatus> shared_status = input_task.status;
388 
389   // `split_task_done_callback` runs only after all splitted tasks are
390   // complete.
391   std::function<void()> split_task_done_callback =
392       [done_callback = input_task.done_callback, output = input_task.output,
393        op_kernel_context = input_task.context, status = shared_status]() {
394         const int num_output = op_kernel_context->num_outputs();
395         for (int i = 0; i < num_output; ++i) {
396           Tensor output_tensor;
397 
398           // Concat would memcpy each input tensor to one output tensor.
399           // In this context, Concat can be further optimized to get rid of
400           // some (probably all) memcpy when input tensors are slices of
401           // another copy.
402           std::vector<Tensor> to_concatenate;
403           to_concatenate.reserve(output->size());
404           for (int j = 0; j < output->size(); ++j) {
405             to_concatenate.push_back(std::move((*output)[j][i]));
406           }
407           const auto concat_status =
408               Concat(op_kernel_context, to_concatenate, &output_tensor);
409           if (!concat_status.ok()) {
410             status->Update(concat_status);
411           }
412 
413           op_kernel_context->set_output(i, std::move(output_tensor));
414         }
415         op_kernel_context->SetStatus(status->status());
416         done_callback();
417       };
418   IncrementalBarrier barrier(split_task_done_callback);
419 
420   std::vector<int64> output_task_sizes;
421 
422   if (open_batch_remaining_slot > 0) {
423     output_task_sizes.push_back(open_batch_remaining_slot);
424   }
425 
426   for (int left_task_size = input_task_size - open_batch_remaining_slot;
427        left_task_size > 0; left_task_size -= max_batch_size) {
428     int next_task_size = std::min(left_task_size, max_batch_size);
429     output_task_sizes.push_back(next_task_size);
430   }
431 
432   const int output_task_num = output_task_sizes.size();
433   input_task.output->resize(output_task_num);
434 
435   for (int i = 0; i < output_task_num; ++i) {
436     (*input_task.output)[i].resize(input_task.context->num_outputs());
437   }
438 
439   output_tasks->reserve(output_task_num);
440   for (int i = 0; i < output_task_num; i++) {
441     output_tasks->push_back(input_task.CreateSplitTask(i, barrier.Inc()));
442   }
443 
444   const int num_input_tensors = input_task.inputs.size();
445 
446   // Splits each input tensor according to `output_task_sizes`, and
447   // initializes input of `output_tasks` with split results.
448   for (int i = 0; i < num_input_tensors; ++i) {
449     std::vector<Tensor> split_tensors;
450     const Tensor& input_tensor = input_task.inputs[i];
451     // TODO(b/154140947):
452     // Figure out the optimal implementation of Split, by using
453     // 'Tensor::Slice' and eliminating unnecessary memcpy as much as possible.
454     const Status split_status = Split(input_task.context, input_tensor,
455                                       output_task_sizes, &split_tensors);
456     if (!split_status.ok()) {
457       return errors::Internal(
458           "When splitting input, Tensor split operation failed: ",
459           split_status.ToString());
460     }
461     if (split_tensors.size() != output_task_sizes.size()) {
462       return errors::Internal(
463           "When splitting input, tensor split operation did not work as "
464           "expected; got ",
465           split_tensors.size(), " splits; expected ", output_task_sizes.size());
466     }
467     for (int j = 0; j < output_tasks->size(); ++j) {
468       BatchTask& output_task = *((*output_tasks)[j]);
469       auto moved_tensor_iter = std::next(split_tensors.begin(), j);
470       std::move(moved_tensor_iter, moved_tensor_iter + 1,
471                 std::back_inserter(output_task.inputs));
472     }
473   }
474   return Status::OK();
475 }
476 
SplitOutputTensors(const std::vector<Tensor> & combined_outputs,BatchT * batch) const477 Status BatchResourceBase::SplitOutputTensors(
478     const std::vector<Tensor>& combined_outputs, BatchT* batch) const {
479   DCHECK_GE(batch->num_tasks(), 1);
480   if (batch->num_tasks() < 1) {
481     return errors::Internal("Batch size expected to be positive; was ",
482                             batch->num_tasks());
483   }
484 
485   std::vector<int64> task_sizes_plus_optional_padding;
486   task_sizes_plus_optional_padding.reserve(batch->num_tasks());
487   for (int i = 0; i < batch->num_tasks(); ++i) {
488     task_sizes_plus_optional_padding.push_back(batch->task(i).size());
489   }
490   const int padding_size =
491       RoundToLowestAllowedBatchSize(batch->size()) - batch->size();
492   if (padding_size > 0) {
493     task_sizes_plus_optional_padding.push_back(padding_size);
494   }
495 
496   // For each output tensor name, a divided-up tensor with one entry per task.
497   std::map<string, std::vector<Tensor>> split_tensors;
498 
499   DCHECK_EQ(batch->task(0).context->num_outputs(), combined_outputs.size());
500   int combined_outputs_size = combined_outputs.size();
501   if (combined_outputs_size != batch->task(0).context->num_outputs()) {
502     return errors::Internal("Wrong number of batched output tensors");
503   }
504 
505   // Generate 'split_tensors' and populate the context outputs.
506   for (int i = 0, iter_limit = combined_outputs.size(); i < iter_limit; ++i) {
507     const Tensor& output_tensor = combined_outputs[i];
508     if (output_tensor.shape().dims() == 0) {
509       return errors::FailedPrecondition(
510           "Batched output tensor has 0 dimensions");
511     }
512     if (output_tensor.shape().dim_size(0) !=
513         static_cast<int64>(batch->size() + padding_size)) {
514       return errors::FailedPrecondition(
515           "Batched output tensor's 0th dimension does not equal the sum of "
516           "the 0th dimension sizes of the input tensors");
517     }
518 
519     std::vector<Tensor> split_tensor;
520     const Status split_status = tensor::Split(
521         output_tensor, task_sizes_plus_optional_padding, &split_tensor);
522     DCHECK(split_status.ok()) << split_status.ToString();
523     if (!split_status.ok()) {
524       return errors::Internal("Tensor split operation failed: ",
525                               split_status.ToString());
526     }
527     DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
528     if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
529       return errors::Internal(
530           "Tensor split operation did not work as expected; got ",
531           split_tensor.size(), " splits; expected ",
532           task_sizes_plus_optional_padding.size());
533     }
534 
535     // Ignore a possible final split_tensors entry containing the padding.
536     for (int j = 0; j < batch->num_tasks(); ++j) {
537       BatchTask& task = *(batch->mutable_task(j));
538       if (task.is_partial) {
539         std::vector<Tensor>& tensor_vector = (*task.output)[task.split_index];
540         tensor_vector[i] = std::move(split_tensor[j]);
541       } else {
542         task.context->set_output(i, split_tensor[j]);
543       }
544     }
545   }
546 
547   return Status::OK();
548 }
549 
ProcessFuncBatch(std::unique_ptr<BatchT> batch) const550 void BatchResourceBase::ProcessFuncBatch(std::unique_ptr<BatchT> batch) const {
551   if (batch->empty()) {
552     return;
553   }
554 
555   // We use the 'propagated_context' from one of the threads which setup one
556   // of the tasks. This will propagate any common context over all the threads
557   // which are running this Session, of which this BatchOp is a part.
558   WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);
559 
560   auto& last_task = batch->task(batch->num_tasks() - 1);
561   OpKernelContext* last_task_context = last_task.context;
562 
563   // Regardless of the outcome, we need to propagate the status to the
564   // individual tasks and signal that they are done. We use MakeCleanup() to
565   // ensure that this happens no matter how we exit the method below.
566   Status status;
567   bool cleanup_done = false;
568   auto cleanup_fn = [&cleanup_done, &batch](const Status& status) {
569     if (cleanup_done) {
570       return;
571     }
572     for (int i = 0; i < batch->num_tasks(); ++i) {
573       if (batch->task(i).is_partial) {
574         batch->mutable_task(i)->status->Update(status);
575       } else {
576         batch->mutable_task(i)->context->SetStatus(status);
577       }
578 
579       batch->mutable_task(i)->done_callback();
580     }
581     cleanup_done = true;
582   };
583 
584   auto finally =
585       gtl::MakeCleanup([&cleanup_fn, &status] { cleanup_fn(status); });
586 
587   status = ValidateBatch(*batch);
588   if (!status.ok()) {
589     return;
590   }
591 
592   std::vector<Tensor> concatenated_tensors;
593   status = ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
594   if (!status.ok()) {
595     return;
596   }
597 
598   std::vector<Tensor> combined_outputs;
599   std::vector<Tensor> args(concatenated_tensors.begin(),
600                            concatenated_tensors.end());
601   const auto& captured_inputs =
602       batch->task(batch->num_tasks() - 1).captured_inputs;
603   args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());
604 
605   uint64 current_time = EnvTime::NowNanos();
606   const string& model_name = GetModelName(last_task_context);
607   for (int i = 0; i < batch->num_tasks(); ++i) {
608     RecordBatchDelayUs((current_time - batch->task(i).start_time) * 1e-3,
609                        model_name,
610                        last_task_context->op_kernel().name_view().data());
611   }
612   // Releases the cleanup method here, because the callback of the function
613   // library runtime will handle it now.
614   finally.release();
615   ProcessFuncBatchImpl(
616       last_task, args, &combined_outputs, [&](const Status& run_status) {
617         Status final_status;
618         auto run_finally = gtl::MakeCleanup([&]() {
619           // We do the cleanup here as an optimization, so that
620           // it runs in the underlying TF inter-op threadpool.
621           // Running it in the threadpool, let's the ensuing
622           // ops be scheduled faster, because the executor will
623           // add them to the front of the threadpool's task
624           // queue rather than the end.
625           cleanup_fn(final_status);
626         });
627         final_status = run_status;
628         if (!final_status.ok()) {
629           return;
630         }
631         final_status = SplitOutputTensors(combined_outputs, batch.get());
632       });
633 }
634 
635 // Processes a batch of one or more BatchTask entries.
ProcessBatch(std::unique_ptr<BatchT> batch) const636 void BatchResourceBase::ProcessBatch(std::unique_ptr<BatchT> batch) const {
637   if (batch->empty()) {
638     return;
639   }
640 
641   WithContext wc(batch->task(batch->num_tasks() - 1).propagated_context);
642 
643   OpKernelContext* last_task_context =
644       batch->task(batch->num_tasks() - 1).context;
645   AsyncOpKernel::DoneCallback last_task_callback =
646       batch->task(batch->num_tasks() - 1).done_callback;
647 
648   OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
649                        last_task_callback);
650 
651   // All tasks should have the same number of input edges.
652   const int num_input_edges = batch->task(0).inputs.size();
653   std::vector<Tensor> concatenated_tensors;
654   const Status concat_status =
655       ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
656   OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback);
657 
658   // Process each input edge one at a time (the typical case has just one).
659   for (int i = 0; i < num_input_edges; ++i) {
660     last_task_context->set_output(i, concatenated_tensors[i]);
661 
662     // Emit batch->num_tasks() - 1 empty output tensors.
663     for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
664       const BatchTask& task = batch->task(task_idx);
665       TensorShape output_shape(task.inputs[i].shape());
666       output_shape.set_dim(0, 0);
667       Tensor* output = nullptr;
668       OP_REQUIRES_OK_ASYNC(
669           task.context, task.context->allocate_output(i, output_shape, &output),
670           task.done_callback);
671     }
672   }
673   // Emit batch->num_tasks() - 1 empty index tensors.
674   for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
675     const BatchTask& task = batch->task(task_idx);
676     TensorShape index_shape({0, 3});
677     Tensor* output = nullptr;
678     OP_REQUIRES_OK_ASYNC(
679         task.context,
680         task.context->allocate_output(num_input_edges, index_shape, &output),
681         task.done_callback);
682   }
683   // Emit all ID tensors.
684   for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
685     const BatchTask& task = batch->task(task_idx);
686     Tensor* id;
687     OP_REQUIRES_OK_ASYNC(task.context,
688                          task.context->allocate_output(num_input_edges + 1,
689                                                        TensorShape({}), &id),
690                          task.done_callback);
691     id->scalar<int64>()() = task.guid;
692   }
693   OP_REQUIRES_OK_ASYNC(
694       last_task_context,
695       EmitIndexTensor(last_task_context, *batch, num_input_edges),
696       last_task_callback);
697 
698   // Signal done for each element of the batch. (At this point, the contexts
699   // are no longer guaranteed to remain live.)
700   for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
701     batch->mutable_task(task_idx)->done_callback();
702   }
703 }
704 
EmitIndexTensor(OpKernelContext * context,const BatchT & batch,int output_index)705 /*static*/ Status BatchResourceBase::EmitIndexTensor(OpKernelContext* context,
706                                                      const BatchT& batch,
707                                                      int output_index) {
708   const TensorShape index_shape({batch.num_tasks(), 3});
709   Tensor* index = nullptr;
710   TF_RETURN_IF_ERROR(
711       context->allocate_output(output_index, index_shape, &index));
712   auto index_flat = index->shaped<int64, 2>({batch.num_tasks(), 3});
713   size_t offset = 0;
714   for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
715     const BatchTask& task = batch.task(task_idx);
716     index_flat(task_idx, 0) = task.guid;
717     index_flat(task_idx, 1) = offset;
718     index_flat(task_idx, 2) = offset + task.size();
719     offset += task.size();
720   }
721   return Status::OK();
722 }
723 
724 // Looks up the batcher queue for 'queue_name'. If it did't previously exist,
725 // creates it.
LookupOrCreateBatcherQueue(const string & queue_name,BatcherQueueT ** queue)726 Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name,
727                                                      BatcherQueueT** queue) {
728   mutex_lock l(batcher_queues_mu_);
729 
730   auto it = batcher_queues_.find(queue_name);
731   if (it != batcher_queues_.end()) {
732     *queue = it->second.get();
733     return Status::OK();
734   }
735 
736   std::unique_ptr<BatcherQueueT> new_queue;
737   auto process_batch_callback = [this](std::unique_ptr<BatchT> batch) {
738     if (!has_process_batch_function_) {
739       ProcessBatch(std::move(batch));
740     } else {
741       ProcessFuncBatch(std::move(batch));
742     }
743   };
744   if (batcher_) {
745     TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
746                                           process_batch_callback, &new_queue));
747   } else if (adaptive_batcher_) {
748     TF_RETURN_IF_ERROR(adaptive_batcher_->AddQueue(
749         adaptive_batcher_queue_options_, process_batch_callback, &new_queue));
750   } else {
751     return errors::Internal("No batcher defined.");
752   }
753   *queue = new_queue.get();
754   batcher_queues_[queue_name] = std::move(new_queue);
755   return Status::OK();
756 }
757 
CreateBatchTask(OpKernelContext * context,std::unique_ptr<BatchResourceBase::BatchTask> * output) const758 Status BatchResourceBase::CreateBatchTask(
759     OpKernelContext* context,
760     std::unique_ptr<BatchResourceBase::BatchTask>* output) const {
761   *output = absl::make_unique<BatchResourceBase::BatchTask>();
762   return Status::OK();
763 }
764 
765 }  // namespace serving
766 }  // namespace tensorflow
767