1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrapper for prefetching_ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.data.ops import iterator_ops
22from tensorflow.python.data.util import structure
23from tensorflow.python.eager import function
24from tensorflow.python.framework import device as framework_device
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import functional_ops
30from tensorflow.python.ops import gen_dataset_ops
31from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
32from tensorflow.python.ops import resource_variable_ops
33from tensorflow.python.util.tf_export import tf_export
34
35
@tf_export("data.experimental.prefetch_to_device")
def prefetch_to_device(device, buffer_size=None):
  """A transformation that prefetches dataset values to the given `device`.

  The returned function first copies the elements of its input dataset to
  `device` (using `copy_to_device`) and then applies `Dataset.prefetch` to
  the copied dataset.

  NOTE: Although the transformation creates a `tf.data.Dataset`, the
  transformation must be the final `Dataset` in the input pipeline.

  Args:
    device: A string. The name of a device to which elements will be prefetched.
    buffer_size: (Optional.) The number of elements to buffer on `device`.
      Defaults to an automatically chosen value.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  def _prefetch_on_device(dataset):
    # Move the elements onto `device` first, then buffer them there.
    on_device = dataset.apply(copy_to_device(target_device=device))
    return on_device.prefetch(buffer_size)

  return _prefetch_on_device
57
58
@tf_export("data.experimental.copy_to_device")
def copy_to_device(target_device, source_device="/cpu:0"):
  """A transformation that copies dataset elements to the given `target_device`.

  Args:
    target_device: The name of a device to which elements will be copied.
    source_device: The original device on which the input dataset (the one the
      returned transformation is applied to) will be placed.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _copy_transform(dataset):
    # The copied dataset is returned with default static optimizations and
    # autotuning both switched off.
    copy_options = dataset_ops.Options()
    optimization = copy_options.experimental_optimization
    optimization.apply_default_optimizations = False
    optimization.autotune = False
    copied = _CopyToDeviceDataset(
        dataset, target_device=target_device, source_device=source_device)
    return copied.with_options(copy_options)

  return _copy_transform
81
82
83# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
84# all inputs to the Op are in host memory, thereby avoiding some unnecessary
85# Sends and Recvs.
class _CopyToDeviceDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that copies elements to another device.

  The copy is implemented as a `GeneratorDataset` placed on the target device.
  Its init/next/finalize functions each issue a `remote_call` to the source
  device, where an iterator over the input dataset is created, advanced and
  destroyed.
  """

  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
    """Constructs a _CopyToDeviceDataset.

    Args:
      input_dataset: The `Dataset` whose elements are copied.
      target_device: The name of the device to which elements will be copied.
        If its parsed device type is "GPU", `make_one_shot_iterator` is
        disabled for this dataset.
      source_device: The name of the device on which the iterator over
        `input_dataset` is created and run.
    """
    self._input_dataset = input_dataset
    self._target_device = target_device
    spec = framework_device.DeviceSpec().from_string(self._target_device)
    self._is_gpu_target = (spec.device_type == "GPU")
    # Keep both forms of the source device: the Python string is used with
    # `ops.device(...)`, the tensor as the `target` of `remote_call`.
    self._source_device_string = source_device
    self._source_device = ops.convert_to_tensor(source_device)

    # The input dataset's variant is wrapped here and unwrapped inside
    # `_init_func`. Presumably this lets it be captured across the remote call
    # boundary; confirm against the WrapDatasetVariant op docs.
    wrap_ds_variant = gen_dataset_ops.wrap_dataset_variant(
        self._input_dataset._variant_tensor)  # pylint: disable=protected-access

    @function.defun()
    def _init_func():
      """Creates an iterator for the input dataset.

      Returns:
        A `string` tensor holding the string handle of the newly created
        iterator, produced only after the iterator has been initialized.
      """
      ds_variant = gen_dataset_ops.unwrap_dataset_variant(wrap_ds_variant)
      resource = gen_dataset_ops.anonymous_iterator(
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      with ops.control_dependencies(
          [gen_dataset_ops.make_iterator(ds_variant, resource)]):
        return gen_dataset_ops.iterator_to_string_handle(resource)

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Runs `_init_func` on the source device and returns its string handle.
    @function.defun()
    def _remote_init_func():
      return functional_ops.remote_call(
          target=self._source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _next_func(string_handle):
      """Calls get_next for created iterator.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The next element of `input_dataset`, flattened to a list of tensors
        according to this dataset's `element_spec`.
      """
      with ops.device(self._source_device_string):
        iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
      return structure.to_tensor_list(self.element_spec, iterator.get_next())

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Runs `_next_func` on the source device. The "experimental_ints_on_device"
    # attribute presumably keeps integer outputs on the target device rather
    # than in host memory; confirm against the runtime's handling of it.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True})
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._input_dataset._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _finalize_func(string_handle):
      """Destroys the iterator resource created by `_init_func`.

      The destroy op is called with `ignore_lookup_error=True`, so presumably
      a handle whose resource no longer exists is not treated as an error.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        An int64 constant 0, produced after the destroy op has run.
      """
      iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
          string_handle,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
      with ops.control_dependencies([
          resource_variable_ops.destroy_resource_op(
              iterator_resource, ignore_lookup_error=True)]):
        return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # Runs `_finalize_func` on the source device.
    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=self._source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal(  # pylint: disable=protected-access
    )
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Register the three remote-call wrapper functions in the current default
    # graph so the GeneratorDataset op below can reference them.
    g = ops.get_default_graph()
    self._init_func.add_to_graph(g)
    self._next_func.add_to_graph(g)
    self._finalize_func.add_to_graph(g)
    # NOTE(review): no matching `disable=protected-scope` is visible in this
    # block, so the directive below appears to be a no-op.
    # pylint: enable=protected-scope

    # The generator itself lives on the target device; each of its functions
    # reaches back to the source device through `remote_call`.
    with ops.device(self._target_device):
      variant_tensor = gen_dataset_ops.generator_dataset(
          self._init_captured_args,
          self._next_captured_args,
          self._finalize_captured_args,
          init_func=self._init_func,
          next_func=self._next_func,
          finalize_func=self._finalize_func,
          **self._input_dataset._flat_structure)  # pylint: disable=protected-access
    super(_CopyToDeviceDataset, self).__init__(input_dataset, variant_tensor)

  # The one-shot iterator implementation needs a zero-argument `_make_dataset`
  # function that captures every input required to build the dataset. Some
  # inputs to the GeneratorDataset are strings, which can't be placed on a GPU,
  # so that path fails when the target is a GPU. One-shot iterators are
  # therefore rejected for GPU targets and delegated to the base class
  # otherwise.
  def make_one_shot_iterator(self):
    if self._is_gpu_target:
      raise ValueError("Cannot create a one shot iterator when using "
                       "`tf.data.experimental.copy_to_device()` on GPU. Please "
                       "use `Dataset.make_initializable_iterator()` instead.")
    else:
      return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
226
227
class _MapOnGpuDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over the elements of its input on a GPU.

  This follows `Dataset.map()`, but the function is traced with the
  `experimental_ints_on_device` attribute and the dataset is built with the
  experimental map-dataset op.
  """

  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism

    # Trace `map_func` against the input dataset's element structure. The
    # extra defun attribute is what distinguishes this from a regular map.
    gpu_defun_kwargs = {"experimental_ints_on_device": True}
    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        defun_kwargs=gpu_defun_kwargs)

    input_variant = input_dataset._variant_tensor  # pylint: disable=protected-access
    traced_fn = self._map_func.function
    variant_tensor = ged_ops.experimental_map_dataset(
        input_variant,
        traced_fn.captured_inputs,
        f=traced_fn,
        use_inter_op_parallelism=use_inter_op_parallelism,
        **self._flat_structure)
    super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    """Returns the list of traced functions used by this dataset."""
    return [self._map_func]

  @property
  def element_spec(self):
    """The element structure produced by the traced `map_func`."""
    return self._map_func.output_structure

  def _transformation_name(self):
    """Returns the name used when tracing `map_func`."""
    return "map_on_gpu()"
258
259
def map_on_gpu(map_func):
  """Returns a transformation that applies `map_func` to each element on a GPU.

  NOTE: This is a highly experimental variant of `tf.data.Dataset.map` that
  runs `map_func` on GPU. It must be applied after the
  `tf.data.experimental.copy_to_device` transformation with a GPU device
  argument.

  Args:
    map_func: A function mapping a nested structure of tensors (matching the
      element structure of the dataset this transformation is applied to) to
      another nested structure of tensors.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _gpu_map_transform(dataset):
    gpu_mapped = _MapOnGpuDataset(dataset, map_func)
    return gpu_mapped

  return _gpu_map_transform
282