# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import structure
from tensorflow.python.eager import function
from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.prefetch_to_device")
def prefetch_to_device(device, buffer_size=None):
  """A transformation that prefetches dataset values to the given `device`.

  NOTE: Although the transformation creates a `tf.data.Dataset`, the
  transformation must be the final `Dataset` in the input pipeline.

  Args:
    device: A string. The name of a device to which elements will be
      prefetched.
    buffer_size: (Optional.) The number of elements to buffer on `device`.
      Defaults to an automatically chosen value.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # Copy the elements to `device` first, then prefetch the copied elements.
    return dataset.apply(
        copy_to_device(target_device=device)).prefetch(buffer_size)

  return _apply_fn


@tf_export("data.experimental.copy_to_device")
def copy_to_device(target_device, source_device="/cpu:0"):
  """A transformation that copies dataset elements to the given `target_device`.

  Args:
    target_device: The name of a device to which elements will be copied.
    source_device: The original device on which `input_dataset` will be placed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    # The copying dataset is created with default optimizations and autotuning
    # explicitly switched off.
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_optimization.autotune = False
    return _CopyToDeviceDataset(
        dataset, target_device=target_device,
        source_device=source_device).with_options(options)

  return _apply_fn


# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
# all inputs to the Op are in host memory, thereby avoiding some unnecessary
# Sends and Recvs.
class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that copies elements to another device."""

  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    The dataset is a `GeneratorDataset` op placed on `target_device` whose
    init / next / finalize functions each issue a `remote_call` to
    `source_device`, where the iterator over `input_dataset` is created,
    advanced and destroyed.

    Args:
      input_dataset: `Dataset` to be copied
      target_device: The name of the device to which elements would be copied.
      source_device: Device where input_dataset would be placed.
    """
    self._input_dataset = input_dataset
    self._target_device = target_device
    # `from_string` is a constructor; there is no need to build an empty
    # `DeviceSpec` instance just to call it.
    spec = framework_device.DeviceSpec.from_string(self._target_device)
    self._is_gpu_target = (spec.device_type == "GPU")
    # The source device is kept both as a Python string (for `ops.device`) and
    # as a tensor (for the `target` argument of `remote_call`).
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    # The wrapped variant is what `_init_func` captures; it is unwrapped again
    # inside that function.
    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

    @function.defun()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor that encapsulates the iterator created.
      """
      ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      resource = gen_dataset_ops.anonymous_iterator(
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      # The string handle is only returned once the iterator is initialized.
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun()
    def _remote_init_func():
      """Invokes `_init_func` on the source device via a remote call."""
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
      return structure.to_tensor_list(self.element_spec, iterator.get_next())

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      """Invokes `_next_func` on the source device via a remote call."""
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._input_dataset._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        Tensor constant 0
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      # `ignore_lookup_error=True`: a resource that cannot be looked up is not
      # treated as an error when destroying it.
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      """Invokes `_finalize_func` on the source device via a remote call."""
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = (
        _remote_finalize_func._get_concrete_function_internal())  # pylint: disable=protected-access
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Register the three remote-call functions with the current default graph
    # before the `GeneratorDataset` op below refers to them.
    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)

    # The generator op itself lives on the target device; each of its
    # functions calls back to the source device.
    with ops.device(self._target_device):
      variant_tensor = gen_dataset_ops.generator_dataset(
          self._init_captured_args,
          self._next_captured_args,
          self._finalize_captured_args,
          init_func=self._init_func,
          next_func=self._next_func,
          finalize_func=self._finalize_func,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
    super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)

  # The one_shot_iterator implementation needs a 0 arg _make_dataset function
  # that thereby captures all the inputs required to create the dataset. Since
  # there are strings that are inputs to the GeneratorDataset which can't be
  # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
  # GPU
  def make_one_shot_iterator(self):
    """Creates a one-shot iterator, unless the target device is a GPU.

    Returns:
      The result of the parent class's `make_one_shot_iterator()`.

    Raises:
      ValueError: If the target device of this dataset is a GPU.
    """
    if self._is_gpu_target:
      raise ValueError("Cannot create a one shot iterator when using "
                       "`tf.data.experimental.copy_to_device()` on GPU. Please "
                       "use `Dataset.make_initializable_iterator()` instead.")
    else:
      return super(_CopyToDeviceDataset, self).make_one_shot_iterator()


class _MapOnGpuDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over elements in its input using a GPU."""

  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism

    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        defun_kwargs={"experimental_ints_on_device": True})
    # NOTE(review): `self._flat_structure` presumably derives from
    # `element_spec`, which reads `self._map_func` -- so `_map_func` has to be
    # assigned before this call. Confirm against `dataset_ops`.
    variant_tensor = ged_ops.experimental_map_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        use_inter_op_parallelism=self._use_inter_op_parallelism,
        **self._flat_structure)
    super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    """Returns the wrapped functions used by this dataset."""
    return [self._map_func]

  @property
  def element_spec(self):
    """The element structure produced by `map_func`."""
    return self._map_func.output_structure

  def _transformation_name(self):
    """Returns the name passed to `StructuredFunctionWrapper`."""
    return "map_on_gpu()"


def map_on_gpu(map_func):
  """Maps `map_func` across the elements of this dataset.

  NOTE: This is a highly experimental version of `tf.data.Dataset.map` that runs
  `map_func` on GPU. It must be used after applying the
  `tf.data.experimental.copy_to_device` transformation with a GPU device
  argument.

  Args:
    map_func: A function mapping a nested structure of tensors (having shapes
      and types defined by the input dataset's `output_shapes` and
      `output_types`) to another nested structure of tensors.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    return _MapOnGpuDataset(dataset, map_func)

  return _apply_fn