# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Operations for automatic batching and unbatching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_batch_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_batch_ops import *
# pylint: enable=wildcard-import


def _validate_tensor_args(args):
  """Checks that every positional argument is an `ops.Tensor`.

  Args:
    args: Iterable of the arguments passed to a batched function.

  Raises:
    ValueError: If any element of `args` is not an `ops.Tensor`.
  """
  for a in args:
    if not isinstance(a, ops.Tensor):
      raise ValueError("All arguments to functions decorated with "
                       "`batch_function` are supposed to be Tensors; "
                       "found %s" % repr(a))


@ops.RegisterGradient("Batch")
def _BatchGrad(op, *out_grads):  # pylint: disable=invalid-name
  """Gradient for batch op.

  Each input gradient is obtained by unbatching the gradient of the
  corresponding batched output, reusing the batch index and id that the
  forward op emitted as its last two outputs.

  Args:
    op: The `Batch` op being differentiated.
    *out_grads: Gradients with respect to each output of `op`. Only the first
      `len(op.inputs)` entries are used.

  Returns:
    A list with one gradient per input of `op`.
  """
  gradients = []
  for i in range(len(op.inputs)):
    gradients.append(
        gen_batch_ops.unbatch(
            out_grads[i],
            op.outputs[-2],
            op.outputs[-1],
            timeout_micros=op.get_attr("grad_timeout_micros"),
            shared_name="batch_gradient_{}_{}".format(op.name, i)))
  return gradients


@ops.RegisterGradient("Unbatch")
def _UnbatchGrad(op, grad):  # pylint: disable=invalid-name
  """Gradient for unbatch op.

  Args:
    op: The `Unbatch` op being differentiated.
    grad: Gradient with respect to the single output of `op`.

  Returns:
    A three-element list: the gradient for the first input of `op`, followed
    by `None` for each of the two remaining inputs, which receive no gradient.
  """
  return [
      gen_batch_ops.unbatch_grad(
          op.inputs[0],
          op.inputs[1],
          grad,
          op.inputs[2],
          shared_name="unbatch_gradient_{}".format(op.name)), None, None
  ]


def batch_function(num_batch_threads,
                   max_batch_size,
                   batch_timeout_micros,
                   allowed_batch_sizes=None,
                   max_enqueued_batches=10):
  """Batches the computation done by the decorated function.

  So, for example, in the following code

  ```python
  @batch_function(1, 2, 3)
  def layer(a):
    return tf.matmul(a, a)

  b = layer(w)
  ```

  if more than one session.run call is simultaneously trying to compute `b`
  the values of `w` will be gathered, non-deterministically concatenated
  along the first axis, and only one thread will run the computation. See the
  documentation of the `Batch` op for more details.

  Assumes that all arguments of the decorated function are Tensors which will
  be batched along their first dimension.

  SparseTensor is not supported. The return value of the decorated function
  must be a Tensor or a list/tuple of Tensors.

  Args:
    num_batch_threads: Number of scheduling threads for processing batches
      of work. Determines the number of batches processed in parallel.
    max_batch_size: Batch sizes will never be bigger than this.
    batch_timeout_micros: Maximum number of microseconds to wait before
      outputting an incomplete batch.
    allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
      does nothing. Otherwise, supplies a list of batch sizes, causing the op
      to pad batches up to one of those sizes. The entries must increase
      monotonically, and the final entry must equal max_batch_size.
    max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.

  Returns:
    The decorated function will return the unbatched computation output Tensors.

  Raises:
    ValueError: When the decorated function is called with an argument that is
      not a Tensor.
  """

  def decorator(fn):  # pylint: disable=missing-docstring

    def decorated(*args):  # pylint: disable=missing-docstring
      # Validate before tracing: building the TensorSpecs below reads `.dtype`
      # and `.shape`, which would otherwise fail with an unrelated
      # AttributeError (or trace needlessly) for non-Tensor arguments.
      _validate_tensor_args(args)

      @function.defun()
      def computation(*computation_args):
        return fn(*computation_args)

      computation = computation.get_concrete_function(
          *[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i))
            for i, x in enumerate(args)])

      with ops.name_scope("batch") as name:
        return gen_batch_ops.batch_function(
            num_batch_threads=num_batch_threads,
            max_batch_size=max_batch_size,
            batch_timeout_micros=batch_timeout_micros,
            allowed_batch_sizes=allowed_batch_sizes,
            max_enqueued_batches=max_enqueued_batches,
            shared_name=name,
            f=computation,
            in_tensors=list(args),
            captured_tensors=computation.captured_inputs,
            Tout=[o.dtype for o in computation.outputs])

    return decorated

  return decorator


def batch_function_v1(num_batch_threads,
                      max_batch_size,
                      batch_timeout_micros,
                      allowed_batch_sizes=None,
                      grad_timeout_micros=60 * 1000 * 1000,
                      unbatch_timeout_micros=60 * 1000 * 1000,
                      max_enqueued_batches=10):
  """Batches the computation done by the decorated function.

  This is the older version of batch_function(). Please use batch_function()
  instead of this.

  Args:
    num_batch_threads: Number of scheduling threads for processing batches
      of work. Determines the number of batches processed in parallel.
    max_batch_size: Batch sizes will never be bigger than this.
    batch_timeout_micros: Maximum number of microseconds to wait before
      outputting an incomplete batch.
    allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
      does nothing. Otherwise, supplies a list of batch sizes, causing the op
      to pad batches up to one of those sizes. The entries must increase
      monotonically, and the final entry must equal max_batch_size.
    grad_timeout_micros: The timeout to use for the gradient. See the
      documentation of the unbatch op for more details. Defaults to 60s.
    unbatch_timeout_micros: The timeout to use for unbatching. See the
      documentation of the unbatch op for more details. Defaults to 60s.
    max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.

  Returns:
    The decorated function will return the unbatched computation output Tensors.

  Raises:
    ValueError: When the decorated function is called with an argument that is
      not a Tensor.
  """
  def decorator(f):  # pylint: disable=missing-docstring
    def decorated(*args):
      with ops.name_scope("batch") as name:
        _validate_tensor_args(args)
        batched_tensors, batch_index, id_t = gen_batch_ops.batch(
            args,
            num_batch_threads=num_batch_threads,
            max_batch_size=max_batch_size,
            batch_timeout_micros=batch_timeout_micros,
            max_enqueued_batches=max_enqueued_batches,
            allowed_batch_sizes=allowed_batch_sizes,
            grad_timeout_micros=grad_timeout_micros,
            shared_name=name)
        outputs = f(*batched_tensors)
        # Normalize a single-Tensor result to a list so it can be unbatched
        # uniformly; the original structure is restored on return.
        if isinstance(outputs, ops.Tensor):
          outputs_list = [outputs]
        else:
          outputs_list = outputs
        with ops.name_scope("unbatch") as unbatch_name:
          unbatched = [
              gen_batch_ops.unbatch(t, batch_index, id_t,
                                    timeout_micros=unbatch_timeout_micros,
                                    shared_name=unbatch_name + "/" + t.name)
              for t in outputs_list]
        if isinstance(outputs, ops.Tensor):
          return unbatched[0]
        return unbatched
    return decorated
  return decorator