1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for reader Datasets."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import os
21
22from tensorflow.python import tf2
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.data.util import convert
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.framework import tensor_spec
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import gen_dataset_ops
31from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
32from tensorflow.python.util import nest
33from tensorflow.python.util.tf_export import tf_export
34
35_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024  # 256 KB
36
37
38def _normalise_fspath(path):
39  """Convert pathlib-like objects to str (__fspath__ compatibility, PEP 519)."""
40  return os.fspath(path) if isinstance(path, os.PathLike) else path
41
42
def _create_or_validate_filenames_dataset(filenames):
  """Returns a dataset of filenames built from, or validated against, the input.

  Args:
    filenames: Either a dataset of filenames or a value convertible to a
      `tf.string` tensor (e.g. a string or a nested list of strings; any
      `os.PathLike` entries are converted to strings first). A dataset is only
      validated. Anything else is flattened to one dimension and sliced into a
      new dataset with one element per filename.

  Returns:
    A dataset of filenames.

  Raises:
    TypeError: If `filenames` is a dataset whose elements are not `tf.string`
      or whose element shape is not compatible with a scalar, or if a
      non-dataset `filenames` does not convert to a `tf.string` tensor.
  """
  if not isinstance(filenames, dataset_ops.DatasetV2):
    # Plain values: normalise path-like objects, convert to a string tensor,
    # flatten, and emit one dataset element per filename.
    path_values = nest.map_structure(_normalise_fspath, filenames)
    path_tensor = ops.convert_to_tensor(path_values, dtype_hint=dtypes.string)
    if path_tensor.dtype != dtypes.string:
      raise TypeError(
          "`filenames` must be a `tf.Tensor` of dtype `tf.string` dtype."
          " Got {}".format(path_tensor.dtype))
    flat_paths = array_ops.reshape(path_tensor, [-1], name="flat_filenames")
    return dataset_ops.DatasetV2.from_tensor_slices(flat_paths)

  # Existing dataset: check the element dtype first, then the element shape.
  if dataset_ops.get_legacy_output_types(filenames) != dtypes.string:
    raise TypeError(
        "`filenames` must be a `tf.data.Dataset` of `tf.string` elements.")
  element_shape = dataset_ops.get_legacy_output_shapes(filenames)
  if not element_shape.is_compatible_with(tensor_shape.TensorShape([])):
    raise TypeError(
        "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
        "elements.")
  return filenames
73
74
def _create_dataset_reader(dataset_creator, filenames, num_parallel_reads=None):
  """Builds a dataset that reads every file in `filenames` with a given reader.

  Args:
    dataset_creator: A function that takes in a single file name and returns a
      dataset.
    filenames: A `tf.data.Dataset` containing one or more filenames.
    num_parallel_reads: The number of parallel reads we should do. `None`
      chains the per-file datasets with `flat_map`; `dataset_ops.AUTOTUNE`
      uses `interleave` with `num_parallel_calls` set to that value; any other
      value is used as the `cycle_length` of a `ParallelInterleaveDataset`.

  Returns:
    A `Dataset` that reads data from `filenames`.
  """

  # NOTE: this inner function is handed to the dataset transformations below,
  # so its name is deliberately left as-is.
  def read_one_file(filename):
    filename_tensor = ops.convert_to_tensor(
        filename, dtypes.string, name="filename")
    return dataset_creator(filename_tensor)

  # Sequential read: one file after another.
  if num_parallel_reads is None:
    return filenames.flat_map(read_one_file)

  # Let the runtime choose the parallelism.
  if num_parallel_reads == dataset_ops.AUTOTUNE:
    return filenames.interleave(
        read_one_file, num_parallel_calls=num_parallel_reads)

  # Fixed parallelism with deterministic (non-sloppy) ordering.
  return ParallelInterleaveDataset(
      filenames,
      read_one_file,
      cycle_length=num_parallel_reads,
      block_length=1,
      sloppy=False,
      buffer_output_elements=None,
      prefetch_input_elements=None)
106
107
class _TextLineDataset(dataset_ops.DatasetSource):
  """A `Dataset` comprising records from one or more text files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None):
    """Creates a `TextLineDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`. Converted to a tensor
        with `""` as the default.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer. A value of 0 results in the default buffering values chosen
        based on the compression type. Converted to a tensor with
        `_DEFAULT_READER_BUFFER_SIZE_BYTES` as the default.
    """
    # Convert the optional arguments first (compression, then buffer size) so
    # the dataset op receives tensors.
    compression = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    read_buffer = convert.optional_param_to_tensor(
        "buffer_size",
        buffer_size,
        argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)

    self._filenames = filenames
    self._compression_type = compression
    self._buffer_size = read_buffer

    dataset_variant = gen_dataset_ops.text_line_dataset(
        filenames, compression, read_buffer)
    super(_TextLineDataset, self).__init__(dataset_variant)

  @property
  def element_spec(self):
    """The spec of each element: a scalar `tf.string`."""
    return tensor_spec.TensorSpec([], dtypes.string)
140
141
@tf_export("data.TextLineDataset", v1=[])
class TextLineDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` comprising lines from one or more text files."""

  def __init__(self,
               filenames,
               compression_type=None,
               buffer_size=None,
               num_parallel_reads=None):
    r"""Creates a `TextLineDataset`.

    The elements of the dataset will be the lines of the input files, using
    the newline character '\n' to denote line splits. The newline characters
    will be stripped off of each element.

    Args:
      filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or
        more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer. A value of 0 results in the default buffering values chosen
        based on the compression type.
      num_parallel_reads: (Optional.) A `tf.int64` scalar representing the
        number of files to read in parallel. If greater than one, the records of
        files read in parallel are outputted in an interleaved order. If your
        input pipeline is I/O bottlenecked, consider setting this parameter to a
        value greater than one to parallelize the I/O. If `None`, files will be
        read sequentially.
    """
    # Normalise the input into a dataset of scalar string filenames.
    filenames_dataset = _create_or_validate_filenames_dataset(filenames)

    # Keep the (unconverted) constructor arguments on the instance.
    self._filenames = filenames_dataset
    self._compression_type = compression_type
    self._buffer_size = buffer_size

    def make_line_reader(filename):
      # Per-file reader; closes over the caller-supplied options.
      return _TextLineDataset(
          filename, compression_type=compression_type, buffer_size=buffer_size)

    # The composed reader pipeline supplies this dataset's variant tensor.
    self._impl = _create_dataset_reader(
        make_line_reader, filenames_dataset, num_parallel_reads)
    super(TextLineDatasetV2, self).__init__(
        self._impl._variant_tensor)  # pylint: disable=protected-access

  @property
  def element_spec(self):
    """The spec of each element: a scalar `tf.string`."""
    return tensor_spec.TensorSpec([], dtypes.string)
189
190
@tf_export(v1=["data.TextLineDataset"])
class TextLineDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` comprising lines from one or more text files."""

  def __init__(self,
               filenames,
               compression_type=None,
               buffer_size=None,
               num_parallel_reads=None):
    # Build the V2 dataset and hand it to the V1 adapter base class.
    v2_dataset = TextLineDatasetV2(
        filenames,
        compression_type=compression_type,
        buffer_size=buffer_size,
        num_parallel_reads=num_parallel_reads)
    super(TextLineDatasetV1, self).__init__(v2_dataset)

  # Reuse the V2 constructor documentation verbatim.
  __init__.__doc__ = TextLineDatasetV2.__init__.__doc__

  # `_filenames` is forwarded to `self._dataset` in both directions.
  # NOTE(review): `self._dataset` is presumably the wrapped V2 dataset stored
  # by `DatasetV1Adapter` -- confirm against the base class.
  @property
  def _filenames(self):
    wrapped = self._dataset
    return wrapped._filenames  # pylint: disable=protected-access

  @_filenames.setter
  def _filenames(self, value):
    wrapped = self._dataset
    wrapped._filenames = value  # pylint: disable=protected-access
213
214
class _TFRecordDataset(dataset_ops.DatasetSource):
  """A `Dataset` comprising records from one or more TFRecord files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None):
    """Creates a `TFRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`. Converted to a tensor
        with `""` as the default.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. 0 means no buffering. Converted to a tensor
        with `_DEFAULT_READER_BUFFER_SIZE_BYTES` as the default.
    """
    # Convert the optional arguments first (compression, then buffer size) so
    # the dataset op receives tensors.
    compression = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    read_buffer = convert.optional_param_to_tensor(
        "buffer_size",
        buffer_size,
        argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)

    self._filenames = filenames
    self._compression_type = compression
    self._buffer_size = read_buffer

    dataset_variant = gen_dataset_ops.tf_record_dataset(
        filenames, compression, read_buffer)
    super(_TFRecordDataset, self).__init__(dataset_variant)

  @property
  def element_spec(self):
    """The spec of each element: a scalar `tf.string`."""
    return tensor_spec.TensorSpec([], dtypes.string)
247
class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over its input and flattens the result."""

  def __init__(self, input_dataset, map_func, cycle_length, block_length,
               sloppy, buffer_output_elements, prefetch_input_elements):
    """See `tf.data.experimental.parallel_interleave()` for details.

    Args:
      input_dataset: The dataset whose elements are passed to `map_func`.
      map_func: A function applied to each input element; it must return a
        `Dataset`.
      cycle_length: Converted to a `tf.int64` tensor named "cycle_length".
      block_length: Converted to a `tf.int64` tensor named "block_length".
      sloppy: `None` selects the "default" determinism setting, a truthy value
        selects "false", and any other value selects "true".
      buffer_output_elements: (Optional.) Converted with `2 * block_length` as
        the default.
      prefetch_input_elements: (Optional.) Converted with `2 * cycle_length` as
        the default.

    Raises:
      TypeError: If the output structure of `map_func` is not a `DatasetSpec`.
    """
    self._input_dataset = input_dataset

    # Wrap the user function, then check that it produces a dataset.
    wrapped_func = dataset_ops.StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    self._map_func = wrapped_func
    output_structure = wrapped_func.output_structure
    if not isinstance(output_structure, dataset_ops.DatasetSpec):
      raise TypeError("`map_func` must return a `Dataset` object.")
    # Elements of this dataset are the elements of the datasets returned by
    # `map_func`.
    self._element_spec = output_structure._element_spec  # pylint: disable=protected-access

    # Tensor conversions, in the order: cycle, block, buffer, prefetch.
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._buffer_output_elements = convert.optional_param_to_tensor(
        "buffer_output_elements",
        buffer_output_elements,
        argument_default=2 * block_length)
    self._prefetch_input_elements = convert.optional_param_to_tensor(
        "prefetch_input_elements",
        prefetch_input_elements,
        argument_default=2 * cycle_length)

    # Translate the tri-state `sloppy` flag into the op's string attribute.
    if sloppy is None:
      determinism = "default"
    else:
      determinism = "false" if sloppy else "true"
    self._deterministic = determinism

    interleave_variant = ged_ops.legacy_parallel_interleave_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        wrapped_func.function.captured_inputs,
        self._cycle_length,
        self._block_length,
        self._buffer_output_elements,
        self._prefetch_input_elements,
        f=wrapped_func.function,
        deterministic=determinism,
        **self._flat_structure)
    super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                    interleave_variant)

  def _functions(self):
    """Returns the wrapped `map_func` as a single-element list."""
    return [self._map_func]

  @property
  def element_spec(self):
    """The element spec taken from the dataset structure `map_func` returns."""
    return self._element_spec

  def _transformation_name(self):
    """Returns the user-facing name of this transformation."""
    return "tf.data.experimental.parallel_interleave()"
300
301
@tf_export("data.TFRecordDataset", v1=[])
class TFRecordDatasetV2(dataset_ops.DatasetV2):
  """A `Dataset` comprising records from one or more TFRecord files."""

  def __init__(self,
               filenames,
               compression_type=None,
               buffer_size=None,
               num_parallel_reads=None):
    """Creates a `TFRecordDataset` to read one or more TFRecord files.

    Each element of the dataset will contain a single TFRecord.

    Args:
      filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or
        more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. If your input pipeline is I/O bottlenecked,
        consider setting this parameter to a value 1-100 MBs. If `None`, a
        sensible default for both local and remote file systems is used.
      num_parallel_reads: (Optional.) A `tf.int64` scalar representing the
        number of files to read in parallel. If greater than one, the records of
        files read in parallel are outputted in an interleaved order. If your
        input pipeline is I/O bottlenecked, consider setting this parameter to a
        value greater than one to parallelize the I/O. If `None`, files will be
        read sequentially.

    Raises:
      TypeError: If any argument does not have the expected type.
      ValueError: If any argument does not have the expected shape.
    """
    filenames = _create_or_validate_filenames_dataset(filenames)

    # Keep the constructor arguments so `_clone` can reuse them. `_filenames`
    # holds the filenames *dataset*, not the raw argument.
    self._filenames = filenames
    self._compression_type = compression_type
    self._buffer_size = buffer_size
    self._num_parallel_reads = num_parallel_reads

    def creator_fn(filename):
      return _TFRecordDataset(filename, compression_type, buffer_size)

    # The composed reader pipeline supplies this dataset's variant tensor.
    self._impl = _create_dataset_reader(creator_fn, filenames,
                                        num_parallel_reads)
    variant_tensor = self._impl._variant_tensor  # pylint: disable=protected-access
    super(TFRecordDatasetV2, self).__init__(variant_tensor)

  def _clone(self,
             filenames=None,
             compression_type=None,
             buffer_size=None,
             num_parallel_reads=None):
    """Returns a new `TFRecordDatasetV2` with selected arguments overridden.

    Each argument left as `None` falls back to the value this dataset was
    constructed with. Only `None` means "not provided": explicit falsy values
    such as `compression_type=""` or `buffer_size=0` are honored, and the
    truth value of `filenames` (which may be a tensor or dataset) is never
    evaluated.

    Args:
      filenames: (Optional.) Replacement for the filenames.
      compression_type: (Optional.) Replacement for the compression type.
      buffer_size: (Optional.) Replacement for the read buffer size.
      num_parallel_reads: (Optional.) Replacement for the read parallelism.

    Returns:
      A new `TFRecordDatasetV2`.
    """
    if filenames is None:
      filenames = self._filenames
    if compression_type is None:
      compression_type = self._compression_type
    if buffer_size is None:
      buffer_size = self._buffer_size
    if num_parallel_reads is None:
      num_parallel_reads = self._num_parallel_reads
    return TFRecordDatasetV2(filenames, compression_type, buffer_size,
                             num_parallel_reads)

  def _inputs(self):
    """Returns the inputs of the underlying reader pipeline."""
    return self._impl._inputs()  # pylint: disable=protected-access

  @property
  def element_spec(self):
    """The spec of each element: a scalar `tf.string`."""
    return tensor_spec.TensorSpec([], dtypes.string)
366
367
@tf_export(v1=["data.TFRecordDataset"])
class TFRecordDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` comprising records from one or more TFRecord files."""

  def __init__(self,
               filenames,
               compression_type=None,
               buffer_size=None,
               num_parallel_reads=None):
    # Build the V2 dataset and hand it to the V1 adapter base class.
    wrapped = TFRecordDatasetV2(filenames, compression_type, buffer_size,
                                num_parallel_reads)
    super(TFRecordDatasetV1, self).__init__(wrapped)

  # Reuse the V2 constructor documentation verbatim.
  __init__.__doc__ = TFRecordDatasetV2.__init__.__doc__

  def _clone(self,
             filenames=None,
             compression_type=None,
             buffer_size=None,
             num_parallel_reads=None):
    """Returns a new `TFRecordDatasetV1` with selected arguments overridden.

    Each argument left as `None` falls back to the value stored on
    `self._dataset`. Only `None` means "not provided": explicit falsy values
    such as `compression_type=""` or `buffer_size=0` are honored, and the
    truth value of `filenames` (which may be a tensor or dataset) is never
    evaluated.

    Args:
      filenames: (Optional.) Replacement for the filenames.
      compression_type: (Optional.) Replacement for the compression type.
      buffer_size: (Optional.) Replacement for the read buffer size.
      num_parallel_reads: (Optional.) Replacement for the read parallelism.

    Returns:
      A new `TFRecordDatasetV1`.
    """
    # pylint: disable=protected-access
    source = self._dataset
    if filenames is None:
      filenames = source._filenames
    if compression_type is None:
      compression_type = source._compression_type
    if buffer_size is None:
      buffer_size = source._buffer_size
    if num_parallel_reads is None:
      num_parallel_reads = source._num_parallel_reads
    return TFRecordDatasetV1(filenames, compression_type, buffer_size,
                             num_parallel_reads)

  # `_filenames` is forwarded to `self._dataset` in both directions.
  @property
  def _filenames(self):
    return self._dataset._filenames  # pylint: disable=protected-access

  @_filenames.setter
  def _filenames(self, value):
    self._dataset._filenames = value  # pylint: disable=protected-access
402
403
class _FixedLengthRecordDataset(dataset_ops.DatasetSource):
  """A `Dataset` of fixed-length records from one or more binary files."""

  def __init__(self,
               filenames,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               buffer_size=None,
               compression_type=None):
    """Creates a `FixedLengthRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_bytes: A `tf.int64` scalar representing the number of bytes in each
        record.
      header_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to skip at the start of a file.
      footer_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to ignore at the end of a file.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes to buffer when reading. Converted to a tensor with
        `_DEFAULT_READER_BUFFER_SIZE_BYTES` as the default.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`. Converted to a tensor
        with `""` as the default.
    """
    # Tensor conversions, in the order: record, header, footer, buffer,
    # compression.
    record_len = ops.convert_to_tensor(
        record_bytes, dtype=dtypes.int64, name="record_bytes")
    header_len = convert.optional_param_to_tensor("header_bytes", header_bytes)
    footer_len = convert.optional_param_to_tensor("footer_bytes", footer_bytes)
    read_buffer = convert.optional_param_to_tensor(
        "buffer_size",
        buffer_size,
        argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)
    compression = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)

    self._filenames = filenames
    self._record_bytes = record_len
    self._header_bytes = header_len
    self._footer_bytes = footer_len
    self._buffer_size = read_buffer
    self._compression_type = compression

    # Note the op's argument order: the header size precedes the record size.
    dataset_variant = gen_dataset_ops.fixed_length_record_dataset_v2(
        filenames, header_len, record_len, footer_len, read_buffer,
        compression)
    super(_FixedLengthRecordDataset, self).__init__(dataset_variant)

  @property
  def element_spec(self):
    """The spec of each element: a scalar `tf.string`."""
    return tensor_spec.TensorSpec([], dtypes.string)
451
452
@tf_export("data.FixedLengthRecordDataset", v1=[])
class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` of fixed-length records from one or more binary files."""

  def __init__(self,
               filenames,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               buffer_size=None,
               compression_type=None,
               num_parallel_reads=None):
    """Creates a `FixedLengthRecordDataset`.

    Args:
      filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or
        more filenames.
      record_bytes: A `tf.int64` scalar representing the number of bytes in each
        record.
      header_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to skip at the start of a file.
      footer_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to ignore at the end of a file.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes to buffer when reading.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      num_parallel_reads: (Optional.) A `tf.int64` scalar representing the
        number of files to read in parallel. If greater than one, the records of
        files read in parallel are outputted in an interleaved order. If your
        input pipeline is I/O bottlenecked, consider setting this parameter to a
        value greater than one to parallelize the I/O. If `None`, files will be
        read sequentially.
    """
    # Normalise the input into a dataset of scalar string filenames.
    filenames_dataset = _create_or_validate_filenames_dataset(filenames)

    # Keep the (unconverted) constructor arguments on the instance.
    self._filenames = filenames_dataset
    self._record_bytes = record_bytes
    self._header_bytes = header_bytes
    self._footer_bytes = footer_bytes
    self._buffer_size = buffer_size
    self._compression_type = compression_type

    def make_record_reader(filename):
      # Per-file reader; closes over the caller-supplied options.
      return _FixedLengthRecordDataset(
          filename,
          record_bytes,
          header_bytes=header_bytes,
          footer_bytes=footer_bytes,
          buffer_size=buffer_size,
          compression_type=compression_type)

    # The composed reader pipeline supplies this dataset's variant tensor.
    self._impl = _create_dataset_reader(
        make_record_reader, filenames_dataset, num_parallel_reads)
    super(FixedLengthRecordDatasetV2, self).__init__(
        self._impl._variant_tensor)  # pylint: disable=protected-access

  @property
  def element_spec(self):
    """The spec of each element: a scalar `tf.string`."""
    return tensor_spec.TensorSpec([], dtypes.string)
509
510
@tf_export(v1=["data.FixedLengthRecordDataset"])
class FixedLengthRecordDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` of fixed-length records from one or more binary files."""

  def __init__(self,
               filenames,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               buffer_size=None,
               compression_type=None,
               num_parallel_reads=None):
    # Build the V2 dataset and hand it to the V1 adapter base class.
    v2_dataset = FixedLengthRecordDatasetV2(
        filenames,
        record_bytes,
        header_bytes=header_bytes,
        footer_bytes=footer_bytes,
        buffer_size=buffer_size,
        compression_type=compression_type,
        num_parallel_reads=num_parallel_reads)
    super(FixedLengthRecordDatasetV1, self).__init__(v2_dataset)

  # Reuse the V2 constructor documentation verbatim.
  __init__.__doc__ = FixedLengthRecordDatasetV2.__init__.__doc__

  # `_filenames` is forwarded to `self._dataset` in both directions.
  # NOTE(review): `self._dataset` is presumably the wrapped V2 dataset stored
  # by `DatasetV1Adapter` -- confirm against the base class.
  @property
  def _filenames(self):
    wrapped = self._dataset
    return wrapped._filenames  # pylint: disable=protected-access

  @_filenames.setter
  def _filenames(self, value):
    wrapped = self._dataset
    wrapped._filenames = value  # pylint: disable=protected-access
537
538
# Bind the unversioned class names: the V1 adapters unless `tf2.enabled()`
# reports true, in which case the V2 classes are used.
if not tf2.enabled():
  TextLineDataset = TextLineDatasetV1
  TFRecordDataset = TFRecordDatasetV1
  FixedLengthRecordDataset = FixedLengthRecordDatasetV1
else:
  TextLineDataset = TextLineDatasetV2
  TFRecordDataset = TFRecordDatasetV2
  FixedLengthRecordDataset = FixedLengthRecordDatasetV2
547