1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrapper for prefetching_ops."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.data.ops import iterator_ops
22from tensorflow.python.eager import context
23from tensorflow.python.eager import function
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_spec
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import functional_ops
30from tensorflow.python.ops import gen_dataset_ops
31from tensorflow.python.ops import resource_variable_ops
32
33
class _PerDeviceGenerator(dataset_ops.DatasetV2):
  """A `dummy` generator dataset that reads one shard of a MultiDeviceIterator.

  The dataset is a `generator_dataset` whose init/next/finalize functions are
  each executed on `source_device` through `functional_ops.remote_call`:

  * init: returns the string handle of `multi_device_iterator_resource`.
  * next: rebuilds the multi-device iterator from that string handle and calls
    `multi_device_iterator_get_next_from_shard` for `shard_num` and
    `incarnation_id`.
  * finalize: returns the int64 constant 0 (no cleanup is performed).
  """

  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, element_structure):
    """Builds the per-device generator dataset.

    Args:
      shard_num: Index of the shard this dataset pulls elements from.
      multi_device_iterator_resource: Resource handle of the MultiDeviceIterator.
      incarnation_id: Tensor returned by `multi_device_iterator_init`; it is
        captured by the next function and forwarded to
        `multi_device_iterator_get_next_from_shard`.
      source_device: Target passed to every `remote_call`, i.e. where the
        init/next/finalize functions actually run.
      element_structure: Structure of the elements produced by this dataset.
    """
    self._structure = element_structure

    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _init_func():
      return multi_device_iterator_string_handle

    init_func_concrete = _init_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _remote_init_func():
      return functional_ops.remote_call(
          target=source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._init_captured_args = self._init_func.captured_inputs

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _next_func(string_handle):
      # pylint: disable=protected-access
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=self._structure._flat_types,
              output_shapes=self._structure._flat_shapes))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=self._structure._flat_types,
          output_shapes=self._structure._flat_shapes)

    next_func_concrete = _next_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True},
        autograph=False)  # Pure graph code.
    def _remote_next_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=self._structure._flat_types,  # pylint: disable=protected-access
          f=next_func_concrete)

    self._next_func = _remote_next_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._next_captured_args = self._next_func.captured_inputs

    # Remember where `incarnation_id` sits among the captured inputs of the
    # next function so `_ReincarnatedPerDeviceGenerator` can swap in a new
    # incarnation id without re-tracing. Compare by identity: tensor `==` may
    # be element-wise (TF2 equality semantics) rather than "same tensor".
    self._incarnation_id_index = -1
    for i, arg in enumerate(self._next_captured_args):
      if arg is incarnation_id:
        self._incarnation_id_index = i

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _finalize_func(unused_string_handle):
      return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func._get_concrete_function_internal()  # pylint: disable=protected-access
    self._finalize_captured_args = self._finalize_func.captured_inputs

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **dataset_ops.flat_structure(self))
    super(_PerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    # TODO(b/116506223): Determine which datasets should be used as inputs here.
    return []

  @property
  def _element_structure(self):
    return self._structure
145
146
class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
  """Creates a _PerDeviceGenerator-like dataset with a new incarnation_id.

  Re-uses the functions from the provided per_device_dataset and just switches
  out the function argument corresponding to the incarnation_id.
  """

  def __init__(self, per_device_dataset, incarnation_id):
    """Builds the dataset from a prototype `_PerDeviceGenerator`.

    Args:
      per_device_dataset: The prototype `_PerDeviceGenerator` whose traced
        init/next/finalize functions are re-used. It is not modified.
      incarnation_id: The new incarnation id tensor that replaces the one
        captured by the prototype's next function.

    Raises:
      ValueError: If the prototype did not locate its incarnation id among the
        captured inputs of its next function.
    """
    # pylint: disable=protected-access
    self._structure = per_device_dataset._structure

    self._init_func = per_device_dataset._init_func
    self._init_captured_args = per_device_dataset._init_captured_args

    self._next_func = per_device_dataset._next_func
    # Copy before mutating: the prototype's list is shared (with the prototype
    # and possibly with the concrete function's own capture list), so writing
    # into it in place would leak this incarnation id into that shared state.
    self._next_captured_args = list(per_device_dataset._next_captured_args)
    incarnation_id_index = per_device_dataset._incarnation_id_index
    if incarnation_id_index < 0:
      # Without this check, index -1 would silently overwrite the last
      # captured argument instead of the incarnation id.
      raise ValueError(
          "Could not find the incarnation_id among the captured inputs of "
          "the per-device next_func.")
    # Replace the captured incarnation id with the new one.
    self._next_captured_args[incarnation_id_index] = incarnation_id

    self._finalize_func = per_device_dataset._finalize_func
    self._finalize_captured_args = per_device_dataset._finalize_captured_args

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **dataset_ops.flat_structure(self))
    super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    # TODO(b/116506223): Determine which datasets should be used as inputs here.
    return []

  @property
  def _element_structure(self):
    return self._structure
188
189
class MultiDeviceIterator(object):
  """An iterator over multiple devices."""

  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Works in both graph and eager mode. In graph mode the per-device iterators
    are initializable iterators that must be initialized by running
    `initializer`; in eager mode they are one-shot iterators and `initializer`
    is a no-op.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
      prefetch_buffer_size: if > 0, then we setup a buffer on each device
        to prefetch into.
      source_device: The host device to place the `dataset` on.

      In order to prevent deadlocks, if the prefetch_buffer_size is greater
      than the max_buffer_size, we set the max_buffer_size to
      prefetch_buffer_size.
    """
    self._dataset = dataset._apply_options()  # pylint: disable=protected-access
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)
    self._max_buffer_size = max_buffer_size
    self._prefetch_buffer_size = prefetch_buffer_size

    if self._prefetch_buffer_size > self._max_buffer_size:
      self._max_buffer_size = self._prefetch_buffer_size

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      # TODO(b/121378567): Get rid of this shared_name hack.
      shared_name = ""
      if context.executing_eagerly():
        shared_name = context.shared_name()
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name=shared_name,
              container="",
              **dataset_ops.flat_structure(self._dataset)))
      if context.executing_eagerly():
        # Delete the resource when this object is deleted
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._multi_device_iterator_resource,
            handle_device=self._source_device)

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=self._max_buffer_size)

    # One prototype dataset per device; `_create_device_dataset` derives the
    # actual per-device dataset from it for the current incarnation id.
    self._prototype_device_datasets = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(
            i, self._multi_device_iterator_resource, self._incarnation_id,
            self._source_device_tensor, self._dataset._element_structure)  # pylint: disable=protected-access
        self._prototype_device_datasets.append(ds)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = self._create_device_dataset(i)
        if context.executing_eagerly():
          self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds))
        else:
          self._device_iterators.append(
              dataset_ops.make_initializable_iterator(ds))

    if not context.executing_eagerly():
      device_iterator_initializers = [
          iterator.initializer for iterator in self._device_iterators
      ]
      self._initializer = control_flow_ops.group(*device_iterator_initializers)

  def _create_device_dataset(self, i):
    """Uses _prototype_device_datasets[i] to build a dataset for the device.

    The prototype is re-bound to the current `self._incarnation_id`, optionally
    prefetched, and has default optimizations and autotuning disabled.
    """
    ds = self._prototype_device_datasets[i]
    ds = _ReincarnatedPerDeviceGenerator(ds, self._incarnation_id)
    if self._prefetch_buffer_size > 0:
      ds = ds.prefetch(self._prefetch_buffer_size)
    # TODO(jsimsa): Enable auto-tuning and optimizations when supported for
    # non-CPU devices.
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    options.experimental_optimization.autotune = False
    ds = ds.with_options(options)
    return ds

  def get_next(self, device=None):
    """Returns the next element given a `device`, else returns all in a list.

    Args:
      device: One of the devices passed to the constructor, or None.

    Returns:
      The next element for `device`; if `device` is None, a list with the next
      element of every device, in the order of `devices`.

    Raises:
      ValueError: If `device` is not one of the constructor's `devices`.
    """
    if device is not None:
      index = self._devices.index(device)
      return self._device_iterators[index].get_next()

    result = []
    # `d` (not `device`) so the loop does not shadow the parameter.
    for i, d in enumerate(self._devices):
      with ops.device(d):
        result.append(self._device_iterators[i].get_next())
    return result

  def get_next_as_optional(self):
    """Returns a list of `iterator_ops.get_next_as_optional` results.

    One entry per device, in the order of the constructor's `devices`; each is
    built under `ops.device` of its device.
    """
    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(iterator_ops.get_next_as_optional(
            self._device_iterators[i]))
    return result

  @property
  def initializer(self):
    """Op that initializes all per-device iterators (a no-op in eager mode)."""
    if context.executing_eagerly():
      return control_flow_ops.no_op()
    return self._initializer

  def _eager_reset(self):
    """Resets the MultiDeviceIterator in eager mode.

    Re-runs `multi_device_iterator_init` to obtain a new incarnation id and
    re-points every existing device iterator resource at a freshly built
    per-device dataset.

    Raises:
      ValueError: If not executing eagerly.
    """
    if not context.executing_eagerly():
      raise ValueError("Eager reset is only supported in eager mode.")
    # pylint: disable=protected-access
    self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
        self._dataset._variant_tensor,
        self._multi_device_iterator_resource,
        max_buffer_size=self._max_buffer_size)
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = self._create_device_dataset(i)
        # Reset the device iterator resources with the new dataset.
        ds_variant = ds._variant_tensor
        gen_dataset_ops.make_iterator(
            ds_variant, self._device_iterators[i]._iterator_resource)

  @property
  def _element_structure(self):
    """Element structure of the wrapped dataset."""
    return dataset_ops.get_structure(self._dataset)
340