1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for reader Datasets."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import csv
22import functools
23import gzip
24
25import numpy as np
26
27from tensorflow.python import tf2
28from tensorflow.python.compat import compat
29from tensorflow.python.data.experimental.ops import error_ops
30from tensorflow.python.data.experimental.ops import parsing_ops
31from tensorflow.python.data.ops import dataset_ops
32from tensorflow.python.data.ops import readers as core_readers
33from tensorflow.python.data.util import convert
34from tensorflow.python.data.util import nest
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_spec
39from tensorflow.python.framework import tensor_util
40from tensorflow.python.lib.io import file_io
41from tensorflow.python.ops import gen_experimental_dataset_ops
42from tensorflow.python.ops import io_ops
43from tensorflow.python.platform import gfile
44from tensorflow.python.util.tf_export import tf_export
45
# Dtypes that may be given directly (instead of a default tensor) as an entry
# of `column_defaults` / `record_defaults`; such entries are replaced by an
# empty default tensor of that dtype, i.e. the column becomes required.
_ACCEPTABLE_CSV_TYPES = (
    dtypes.float32,
    dtypes.float64,
    dtypes.int32,
    dtypes.int64,
    dtypes.string,
)
48
49
def _is_valid_int32(str_val):
  """Returns whether `str_val` parses as an integer representable in int32.

  Args:
    str_val: String to test.

  Returns:
    True if `str_val` converts to both int32 and int64 and the two converted
    values are equal; False if either conversion raises `ValueError` or
    `OverflowError`.
  """
  try:
    narrow = dtypes.int32.as_numpy_dtype(str_val)
    wide = dtypes.int64.as_numpy_dtype(str_val)
  except (ValueError, OverflowError):
    return False
  # Comparing against the int64 value guards against an int32 conversion that
  # overflowed without raising.
  return narrow == wide
57
58
def _is_valid_int64(str_val):
  """Returns whether `str_val` can be converted to an int64.

  Args:
    str_val: String to test.

  Returns:
    True if the int64 conversion succeeds, False if it raises `ValueError` or
    `OverflowError`.
  """
  try:
    dtypes.int64.as_numpy_dtype(str_val)
  except (ValueError, OverflowError):
    return False
  else:
    return True
65
66
67def _is_valid_float(str_val, float_dtype):
68  try:
69    return float_dtype.as_numpy_dtype(str_val) < np.inf
70  except ValueError:
71    return False
72
73
def _infer_type(str_val, na_value, prev_type):
  """Infers the narrowest CSV dtype compatible with `str_val` and `prev_type`.

  Candidate dtypes are tried in the order int32, int64, float32, float64,
  string (least to most permissive). A candidate is chosen only if `str_val`
  is valid for it and it is not narrower than `prev_type`, so the inference
  for a column can only widen as more values are seen.

  Args:
    str_val: String value to infer the type of.
    na_value: Additional string to recognize as a NA/NaN CSV value.
    prev_type: Type previously inferred for this column from the values seen
      so far, or `None` if nothing has been inferred yet.

  Returns:
    The inferred dtype. If `str_val` is empty or equals `na_value`,
    `prev_type` is returned unchanged.
  """
  if str_val in ("", na_value):
    # A null field carries no information about the column's type.
    return prev_type

  # (dtype, validator) pairs, ordered from least to most permissive dtype.
  candidates = (
      (dtypes.int32, _is_valid_int32),
      (dtypes.int64, _is_valid_int64),
      (dtypes.float32, lambda s: _is_valid_float(s, dtypes.float32)),
      (dtypes.float64, lambda s: _is_valid_float(s, dtypes.float64)),
      (dtypes.string, lambda s: True),
  )

  # Turns true once the scan has reached `prev_type`; earlier candidates are
  # narrower than what previously seen values already required.
  reached_prev = prev_type is None
  for dtype, is_valid in candidates:
    reached_prev = reached_prev or dtype == prev_type
    if is_valid(str_val) and reached_prev:
      return dtype
  # Only reachable if `prev_type` is not one of the candidate dtypes.
  return None
109
110
111def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header,
112                  file_io_fn):
113  """Generator that yields rows of CSV file(s) in order."""
114  for fn in filenames:
115    with file_io_fn(fn) as f:
116      rdr = csv.reader(
117          f,
118          delimiter=field_delim,
119          quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE)
120      if header:
121        next(rdr)  # Skip header lines
122
123      for csv_row in rdr:
124        if len(csv_row) != num_cols:
125          raise ValueError(
126              "Problem inferring types: CSV row has different number of fields "
127              "than expected.")
128        yield csv_row
129
130
def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim,
                           na_value, header, num_rows_for_inference,
                           select_columns, file_io_fn):
  """Infers per-column default tensors from the leading rows of CSV files.

  Args:
    filenames: CSV files to scan, in order.
    num_cols: Total number of fields each row must have.
    field_delim: Field delimiter.
    use_quote_delim: Whether quoted fields are parsed.
    na_value: Additional string treated as a missing value.
    header: Whether each file starts with a header row to skip.
    num_rows_for_inference: Maximum number of data rows, counted across all
      files, to inspect; `None` inspects every row.
    select_columns: Column indices to infer, or `None` for all columns.
    file_io_fn: Callable that opens a filename as a text file object.

  Returns:
    One length-1 constant tensor per selected column, holding `""` for string
    columns and `0` otherwise. Columns for which no type could be inferred
    (e.g. only null values were seen) are treated as strings.

  Raises:
    ValueError: If a scanned row does not have `num_cols` fields.
  """
  columns = range(num_cols) if select_columns is None else select_columns
  column_types = [None] * len(columns)

  rows = _next_csv_row(filenames, num_cols, field_delim, use_quote_delim,
                       header, file_io_fn)
  for row_number, row in enumerate(rows):
    if (num_rows_for_inference is not None and
        row_number >= num_rows_for_inference):
      break
    column_types = [
        _infer_type(row[col], na_value, prev)
        for col, prev in zip(columns, column_types)
    ]

  defaults = []
  for dtype in column_types:
    # Columns never typed fall back to string.
    dtype = dtype or dtypes.string
    default_value = "" if dtype is dtypes.string else 0
    defaults.append(constant_op.constant([default_value], dtype=dtype))
  return defaults
156
157
158def _infer_column_names(filenames, field_delim, use_quote_delim, file_io_fn):
159  """Infers column names from first rows of files."""
160  csv_kwargs = {
161      "delimiter": field_delim,
162      "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE
163  }
164  with file_io_fn(filenames[0]) as f:
165    try:
166      column_names = next(csv.reader(f, **csv_kwargs))
167    except StopIteration:
168      raise ValueError(("Received StopIteration when reading the header line "
169                        "of %s.  Empty file?") % filenames[0])
170
171  for name in filenames[1:]:
172    with file_io_fn(name) as f:
173      try:
174        if next(csv.reader(f, **csv_kwargs)) != column_names:
175          raise ValueError(
176              "Files have different column names in the header row.")
177      except StopIteration:
178        raise ValueError(("Received StopIteration when reading the header line "
179                          "of %s.  Empty file?") % filenames[0])
180  return column_names
181
182
183def _get_sorted_col_indices(select_columns, column_names):
184  """Transforms select_columns argument into sorted column indices."""
185  names_to_indices = {n: i for i, n in enumerate(column_names)}
186  num_cols = len(column_names)
187
188  results = []
189  for v in select_columns:
190    # If value is already an int, check if it's valid.
191    if isinstance(v, int):
192      if v < 0 or v >= num_cols:
193        raise ValueError(
194            "Column index %d specified in select_columns out of valid range." %
195            v)
196      results.append(v)
197    # Otherwise, check that it's a valid column name and convert to the
198    # the relevant column index.
199    elif v not in names_to_indices:
200      raise ValueError(
201          "Value '%s' specified in select_columns not a valid column index or "
202          "name." % v)
203    else:
204      results.append(names_to_indices[v])
205
206  # Sort and ensure there are no duplicates
207  results = sorted(set(results))
208  if len(results) != len(select_columns):
209    raise ValueError("select_columns contains duplicate columns")
210  return results
211
212
213def _maybe_shuffle_and_repeat(
214    dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed):
215  """Optionally shuffle and repeat dataset, as requested."""
216  if shuffle:
217    dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed)
218  if num_epochs != 1:
219    dataset = dataset.repeat(num_epochs)
220  return dataset
221
222
def make_tf_record_dataset(file_pattern,
                           batch_size,
                           parser_fn=None,
                           num_epochs=None,
                           shuffle=True,
                           shuffle_buffer_size=None,
                           shuffle_seed=None,
                           prefetch_buffer_size=None,
                           num_parallel_reads=None,
                           num_parallel_parser_calls=None,
                           drop_final_batch=False):
  """Reads and optionally parses TFRecord files into a batched dataset.

  The pipeline lists the matching files, reads them with `TFRecordDataset`,
  optionally shuffles and repeats the records, optionally maps `parser_fn`
  over them, batches them and (unless disabled) prefetches batches.

  Args:
    file_pattern: List of files or patterns of TFRecord file paths.
      See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    parser_fn: (Optional.) A function accepting string input to parse
      and process the record contents. It must map records to components of
      a fixed shape, so they can be batched. By default, the record contents
      are used unmodified.
    num_epochs: (Optional.) An int specifying the number of times this
      dataset is repeated. If None (the default), cycles through the
      dataset forever.
    shuffle: (Optional.) A bool that indicates whether the input files and
      records should be shuffled. Defaults to `True`.
    shuffle_buffer_size: (Optional.) Buffer size to use for shuffling
      records. A large buffer size ensures better shuffling, but increases
      memory usage and startup time. Defaults to `10000`.
    shuffle_seed: (Optional.) Randomization seed to use for shuffling.
    prefetch_buffer_size: (Optional.) An int specifying the number of
      batches to prefetch for performance improvement. Defaults to
      auto-tune. Set to 0 to disable prefetching.
    num_parallel_reads: (Optional.) Number of threads used to read records
      from files. By default or if set to a value >1, the results will be
      interleaved. Defaults to `24`.
    num_parallel_parser_calls: (Optional.) Number of records to parse in
      parallel with `parser_fn`. Defaults to `batch_size`.
    drop_final_batch: (Optional.) Whether the last batch should be dropped
      if its size is smaller than `batch_size`. Defaults to `False`; note that
      the final batch is always dropped when `num_epochs` is `None`, because
      the input then repeats indefinitely.

  Returns:
    A dataset, where each element matches the output of `parser_fn`
    except it will have an additional leading `batch_size` dimension,
    or a `batch_size`-length 1-D tensor of strings if `parser_fn` is
    unspecified.
  """
  # NOTE: We considered auto-tuning this value, but there is a concern
  # that this affects the mixing of records from different files, which
  # could affect training convergence/accuracy, so we are defaulting to
  # a constant for now.
  reads = 24 if num_parallel_reads is None else num_parallel_reads

  # TODO(josh11b): if num_parallel_parser_calls is None, use some function
  # of num cores instead of `batch_size`.
  parser_calls = (batch_size if num_parallel_parser_calls is None
                  else num_parallel_parser_calls)

  prefetch_size = (dataset_ops.AUTOTUNE if prefetch_buffer_size is None
                   else prefetch_buffer_size)

  files = dataset_ops.Dataset.list_files(
      file_pattern, shuffle=shuffle, seed=shuffle_seed)
  records = core_readers.TFRecordDataset(files, num_parallel_reads=reads)

  # TODO(josh11b): Auto-tune this value when not specified
  buffer_size = 10000 if shuffle_buffer_size is None else shuffle_buffer_size
  records = _maybe_shuffle_and_repeat(
      records, num_epochs, shuffle, buffer_size, shuffle_seed)

  if parser_fn is not None:
    records = records.map(parser_fn, num_parallel_calls=parser_calls)

  # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  batches = records.batch(
      batch_size, drop_remainder=drop_final_batch or num_epochs is None)

  if prefetch_size == 0:
    return batches
  return batches.prefetch(buffer_size=prefetch_size)
319
320
def _get_csv_file_io_fn(compression_type):
  """Returns a function opening a CSV file as text, for probing its contents.

  Column names and column defaults are inferred in Python while the dataset is
  being built, so the compression type must be statically known.

  Args:
    compression_type: `None`, or a Python string / `tf.string` scalar whose
      static value is `""`, `"GZIP"` or `"ZLIB"`.

  Returns:
    A callable mapping a filename to a file object opened in text mode.

  Raises:
    ValueError: If the static value of `compression_type` cannot be
      determined, is `"ZLIB"` (not supported for probing), or is any other
      unsupported value.
  """
  if compression_type is None:
    return lambda filename: file_io.FileIO(filename, "r")

  compression_type_value = tensor_util.constant_value(compression_type)
  if compression_type_value is None:
    raise ValueError("Received unknown compression_type")
  # A constant `tf.string` tensor evaluates to `bytes` (possibly wrapped in a
  # 0-d numpy array) rather than `str`; normalize it so that tensor-valued
  # arguments compare equal to the string literals below.
  if (isinstance(compression_type_value, np.ndarray) and
      compression_type_value.ndim == 0):
    compression_type_value = compression_type_value.item()
  if isinstance(compression_type_value, bytes):
    compression_type_value = compression_type_value.decode("utf-8")

  if compression_type_value == "GZIP":
    return lambda filename: gzip.open(filename, "rt")
  if compression_type_value == "ZLIB":
    raise ValueError(
        "compression_type (%s) is not supported for probing columns" %
        compression_type)
  if compression_type_value != "":
    raise ValueError("compression_type (%s) is not supported" %
                     compression_type)
  return lambda filename: file_io.FileIO(filename, "r")


@tf_export("data.experimental.make_csv_dataset", v1=[])
def make_csv_dataset_v2(
    file_pattern,
    batch_size,
    column_names=None,
    column_defaults=None,
    label_name=None,
    select_columns=None,
    field_delim=",",
    use_quote_delim=True,
    na_value="",
    header=True,
    num_epochs=None,
    shuffle=True,
    shuffle_buffer_size=10000,
    shuffle_seed=None,
    prefetch_buffer_size=None,
    num_parallel_reads=None,
    sloppy=False,
    num_rows_for_inference=100,
    compression_type=None,
    ignore_errors=False,
):
  """Reads CSV files into a dataset.

  Reads CSV files into a dataset, where each element of the dataset is a
  (features, labels) tuple that corresponds to a batch of CSV rows. The features
  dictionary maps feature column names to `Tensor`s containing the corresponding
  feature data, and labels is a `Tensor` containing the batch's label data.

  By default, the first rows of the CSV files are expected to be headers listing
  the column names. If the first rows are not headers, set `header=False` and
  provide the column names with the `column_names` argument.

  By default, the dataset is repeated indefinitely, reshuffling the order each
  time. This behavior can be modified by setting the `num_epochs` and `shuffle`
  arguments.

  For example, suppose you have a CSV file containing

  | Feature_A | Feature_B |
  | --------- | --------- |
  | 1         | "a"       |
  | 2         | "b"       |
  | 3         | "c"       |
  | 4         | "d"       |

  ```
  # No label column specified
  dataset = tf.data.experimental.make_csv_dataset(filename, batch_size=2)
  iterator = dataset.as_numpy_iterator()
  print(next(iterator))
  # prints a dictionary of batched features:
  # OrderedDict([('Feature_A', array([1, 4], dtype=int32)),
  #              ('Feature_B', array([b'a', b'd'], dtype=object))])
  ```

  ```
  # Set Feature_B as label column
  dataset = tf.data.experimental.make_csv_dataset(
      filename, batch_size=2, label_name="Feature_B")
  iterator = dataset.as_numpy_iterator()
  print(next(iterator))
  # prints (features, labels) tuple:
  # (OrderedDict([('Feature_A', array([1, 2], dtype=int32))]),
  #  array([b'a', b'b'], dtype=object))
  ```

  See the
  [Load CSV data guide](https://www.tensorflow.org/tutorials/load_data/csv) for
  more examples of using `make_csv_dataset` to read CSV data.

  Args:
    file_pattern: List of files or patterns of file paths containing CSV
      records. See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    column_names: An optional list of strings that corresponds to the CSV
      columns, in order. One per column of the input record. If this is not
      provided, infers the column names from the first row of the records.
      These names will be the keys of the features dict of each dataset element.
    column_defaults: An optional list of default values for the CSV fields. One
      item per selected column of the input record. Each item in the list is
      either a valid CSV dtype (float32, float64, int32, int64, or string), or a
      `Tensor` with one of the aforementioned types. The tensor can either be
      a scalar default value (if the column is optional), or an empty tensor (if
      the column is required). If a dtype is provided instead of a tensor, the
      column is also treated as required. If this list is not provided, tries
      to infer types based on reading the first num_rows_for_inference rows of
      files specified, and assumes all columns are optional, defaulting to `0`
      for numeric values and `""` for string values. If both this and
      `select_columns` are specified, these must have the same lengths, and
      `column_defaults` is assumed to be sorted in order of increasing column
      index.
    label_name: An optional string corresponding to the label column. If
      provided, the data for this column is returned as a separate `Tensor` from
      the features dictionary, so that the dataset complies with the format
      expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input
      function.
    select_columns: An optional list of integer indices or string column
      names, that specifies a subset of columns of CSV data to select. If
      column names are provided, these must correspond to names provided in
      `column_names` or inferred from the file header lines. When this argument
      is specified, only a subset of CSV columns will be parsed and returned,
      corresponding to the columns specified. Using this results in faster
      parsing and lower memory usage. If both this and `column_defaults` are
      specified, these must have the same lengths, and `column_defaults` is
      assumed to be sorted in order of increasing column index.
    field_delim: An optional `string`. Defaults to `","`. Char delimiter to
      separate fields in a record.
    use_quote_delim: An optional bool. Defaults to `True`. If false, treats
      double quotation marks as regular characters inside of the string fields.
    na_value: Additional string to recognize as NA/NaN.
    header: A bool that indicates whether the first rows of provided CSV files
      correspond to header lines with column names, and should not be included
      in the data.
    num_epochs: An int specifying the number of times this dataset is repeated.
      If None, cycles through the dataset forever.
    shuffle: A bool that indicates whether the input should be shuffled.
    shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
      ensures better shuffling, but increases memory usage and startup time.
    shuffle_seed: Randomization seed to use for shuffling.
    prefetch_buffer_size: An int specifying the number of feature
      batches to prefetch for performance improvement. Recommended value is the
      number of batches consumed per training step. Defaults to auto-tune.
    num_parallel_reads: Number of threads used to read CSV records from files.
      If >1, the results will be interleaved. Defaults to `1`.
    sloppy: If `True`, reading performance will be improved at
      the cost of non-deterministic ordering. If `False`, the order of elements
      produced is deterministic prior to shuffling (elements are still
      randomized if `shuffle=True`. Note that if the seed is set, then order
      of elements after shuffling is deterministic). Defaults to `False`.
    num_rows_for_inference: Number of rows of a file to use for type inference
      if `column_defaults` is not provided. If None, reads all the rows of all
      the files. Defaults to 100.
    compression_type: (Optional.) A `tf.string` scalar evaluating to one of
      `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no compression.
      If `column_names` or `column_defaults` must be inferred, its value must
      be statically known, and `"ZLIB"` is not supported.
    ignore_errors: (Optional.) If `True`, ignores errors with CSV file parsing,
      such as malformed data or empty lines, and moves on to the next valid
      CSV record. Otherwise, the dataset raises an error and stops processing
      when encountering any invalid records. Defaults to `False`.

  Returns:
    A dataset, where each element is a (features, labels) tuple that corresponds
    to a batch of `batch_size` CSV rows. The features dictionary maps feature
    column names to `Tensor`s containing the corresponding column data, and
    labels is a `Tensor` containing the column data for the label column
    specified by `label_name`.

  Raises:
    ValueError: If any of the arguments is malformed.
  """
  if num_parallel_reads is None:
    num_parallel_reads = 1

  if prefetch_buffer_size is None:
    prefetch_buffer_size = dataset_ops.AUTOTUNE

  # Create dataset of all matching filenames
  filenames = _get_file_names(file_pattern, False)
  dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
  if shuffle:
    dataset = dataset.shuffle(len(filenames), shuffle_seed)

  # Clean arguments; figure out column names and defaults
  if column_names is None or column_defaults is None:
    # Files only need to be opened in Python when something must be inferred.
    file_io_fn = _get_csv_file_io_fn(compression_type)
  if column_names is None:
    if not header:
      raise ValueError("Cannot infer column names without a header line.")
    # If column names are not provided, infer from the header lines
    column_names = _infer_column_names(filenames, field_delim, use_quote_delim,
                                       file_io_fn)
  if len(column_names) != len(set(column_names)):
    raise ValueError("Cannot have duplicate column names.")

  if select_columns is not None:
    select_columns = _get_sorted_col_indices(select_columns, column_names)

  if column_defaults is not None:
    column_defaults = [
        constant_op.constant([], dtype=x)
        if not tensor_util.is_tf_type(x) and x in _ACCEPTABLE_CSV_TYPES else x
        for x in column_defaults
    ]
  else:
    # If column defaults are not provided, infer from records at graph
    # construction time
    column_defaults = _infer_column_defaults(filenames, len(column_names),
                                             field_delim, use_quote_delim,
                                             na_value, header,
                                             num_rows_for_inference,
                                             select_columns, file_io_fn)

  if select_columns is not None and len(column_defaults) != len(select_columns):
    raise ValueError(
        "If specified, column_defaults and select_columns must have same "
        "length."
    )
  if select_columns is not None and len(column_names) > len(select_columns):
    # Pick the relevant subset of column names
    column_names = [column_names[i] for i in select_columns]

  if label_name is not None and label_name not in column_names:
    raise ValueError("`label_name` provided must be one of the columns.")

  def filename_to_dataset(filename):
    dataset = CsvDataset(
        filename,
        record_defaults=column_defaults,
        field_delim=field_delim,
        use_quote_delim=use_quote_delim,
        na_value=na_value,
        select_cols=select_columns,
        header=header,
        compression_type=compression_type
    )
    if ignore_errors:
      dataset = dataset.apply(error_ops.ignore_errors())
    return dataset

  def map_fn(*columns):
    """Organizes columns into a features dictionary.

    Args:
      *columns: list of `Tensor`s corresponding to one csv record.
    Returns:
      An OrderedDict of feature names to values for that particular record. If
      label_name is provided, extracts the label feature to be returned as the
      second element of the tuple.
    """
    features = collections.OrderedDict(zip(column_names, columns))
    if label_name is not None:
      label = features.pop(label_name)
      return features, label
    return features

  if num_parallel_reads == dataset_ops.AUTOTUNE:
    dataset = dataset.interleave(
        filename_to_dataset, num_parallel_calls=num_parallel_reads)
    options = dataset_ops.Options()
    options.experimental_deterministic = not sloppy
    dataset = dataset.with_options(options)
  else:
    # Read files sequentially (if num_parallel_reads=1) or in parallel
    def apply_fn(dataset):
      return core_readers.ParallelInterleaveDataset(
          dataset,
          filename_to_dataset,
          cycle_length=num_parallel_reads,
          block_length=1,
          sloppy=sloppy,
          buffer_output_elements=None,
          prefetch_input_elements=None)

    dataset = dataset.apply(apply_fn)

  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # Apply batch before map for perf, because map has high overhead relative
  # to the size of the computation in each map.
  # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  dataset = dataset.batch(batch_size=batch_size,
                          drop_remainder=num_epochs is None)
  dataset = dataset_ops.MapDataset(
      dataset, map_fn, use_inter_op_parallelism=False)
  dataset = dataset.prefetch(prefetch_buffer_size)

  return dataset
608
609
@tf_export(v1=["data.experimental.make_csv_dataset"])
def make_csv_dataset_v1(
    file_pattern,
    batch_size,
    column_names=None,
    column_defaults=None,
    label_name=None,
    select_columns=None,
    field_delim=",",
    use_quote_delim=True,
    na_value="",
    header=True,
    num_epochs=None,
    shuffle=True,
    shuffle_buffer_size=10000,
    shuffle_seed=None,
    prefetch_buffer_size=None,
    num_parallel_reads=None,
    sloppy=False,
    num_rows_for_inference=100,
    compression_type=None,
    ignore_errors=False,
):  # pylint: disable=missing-docstring
  # TF1 entry point: build the V2 dataset, then wrap it for the V1 API. The
  # docstring is copied from `make_csv_dataset_v2` below.
  v2_dataset = make_csv_dataset_v2(
      file_pattern=file_pattern,
      batch_size=batch_size,
      column_names=column_names,
      column_defaults=column_defaults,
      label_name=label_name,
      select_columns=select_columns,
      field_delim=field_delim,
      use_quote_delim=use_quote_delim,
      na_value=na_value,
      header=header,
      num_epochs=num_epochs,
      shuffle=shuffle,
      shuffle_buffer_size=shuffle_buffer_size,
      shuffle_seed=shuffle_seed,
      prefetch_buffer_size=prefetch_buffer_size,
      num_parallel_reads=num_parallel_reads,
      sloppy=sloppy,
      num_rows_for_inference=num_rows_for_inference,
      compression_type=compression_type,
      ignore_errors=ignore_errors)
  return dataset_ops.DatasetV1Adapter(v2_dataset)


make_csv_dataset_v1.__doc__ = make_csv_dataset_v2.__doc__
640
641
642_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024  # 4 MB
643
644
645@tf_export("data.experimental.CsvDataset", v1=[])
646class CsvDatasetV2(dataset_ops.DatasetSource):
647  """A Dataset comprising lines from one or more CSV files."""
648
649  def __init__(self,
650               filenames,
651               record_defaults,
652               compression_type=None,
653               buffer_size=None,
654               header=False,
655               field_delim=",",
656               use_quote_delim=True,
657               na_value="",
658               select_cols=None,
659               exclude_cols=None):
660    """Creates a `CsvDataset` by reading and decoding CSV files.
661
662    The elements of this dataset correspond to records from the file(s).
663    RFC 4180 format is expected for CSV files
664    (https://tools.ietf.org/html/rfc4180)
665    Note that we allow leading and trailing spaces with int or float field.
666
667
668    For example, suppose we have a file 'my_file0.csv' with four CSV columns of
669    different data types:
670    ```
671    abcdefg,4.28E10,5.55E6,12
672    hijklmn,-5.3E14,,2
673    ```
674
675    We can construct a CsvDataset from it as follows:
676
677    ```python
678     dataset = tf.data.experimental.CsvDataset(
679        "my_file*.csv",
680        [tf.float32,  # Required field, use dtype or empty tensor
681         tf.constant([0.0], dtype=tf.float32),  # Optional field, default to 0.0
682         tf.int32,  # Required field, use dtype or empty tensor
683         ],
684        select_cols=[1,2,3]  # Only parse last three columns
685    )
686    ```
687
688    The expected output of its iterations is:
689
690    ```python
691    for element in dataset:
692      print(element)
693
694    >> (4.28e10, 5.55e6, 12)
695    >> (-5.3e14, 0.0, 2)
696    ```
697
698
699    Alternatively, suppose we have a CSV file of floats with 200 columns,
700    and we want to use all columns besides the first. We can construct a
701    CsvDataset from it as follows:
702
703    ```python
704    dataset = tf.data.experimental.CsvDataset(
705        "my_file.csv",
706        [tf.float32] * 199,  # Parse 199 required columns as floats
707        exclude_cols=[0]  # Parse all columns except the first
708    )
709    ```
710
711    Args:
712      filenames: A `tf.string` tensor containing one or more filenames.
713      record_defaults: A list of default values for the CSV fields. Each item in
714        the list is either a valid CSV `DType` (float32, float64, int32, int64,
715        string), or a `Tensor` object with one of the above types. One per
716        column of CSV data, with either a scalar `Tensor` default value for the
717        column if it is optional, or `DType` or empty `Tensor` if required. If
718        both this and `select_columns` are specified, these must have the same
719        lengths, and `column_defaults` is assumed to be sorted in order of
720        increasing column index. If both this and 'exclude_cols' are specified,
721        the sum of lengths of record_defaults and exclude_cols should equal
722        the total number of columns in the CSV file.
723      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
724        `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
725        compression.
726      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
727        to buffer while reading files. Defaults to 4MB.
728      header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
729        have header line(s) that should be skipped when parsing. Defaults to
730        `False`.
731      field_delim: (Optional.) A `tf.string` scalar containing the delimiter
732        character that separates fields in a record. Defaults to `","`.
733      use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
734        double quotation marks as regular characters inside of string fields
735        (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
736      na_value: (Optional.) A `tf.string` scalar indicating a value that will
737        be treated as NA/NaN.
738      select_cols: (Optional.) A sorted list of column indices to select from
739        the input data. If specified, only this subset of columns will be
740        parsed. Defaults to parsing all columns. At most one of `select_cols`
741        and `exclude_cols` can be specified.
742      exclude_cols: (Optional.) A sorted list of column indices to exclude from
743        the input data. If specified, only the complement of this set of column
744        will be parsed. Defaults to parsing all columns. At most one of
745        `select_cols` and `exclude_cols` can be specified.
746
747    Raises:
748       InvalidArgumentError: If exclude_cols is not None and
749           len(exclude_cols) + len(record_defaults) does not match the total
750           number of columns in the file(s)
751
752
753    """
754    self._filenames = ops.convert_to_tensor(
755        filenames, dtype=dtypes.string, name="filenames")
756    self._compression_type = convert.optional_param_to_tensor(
757        "compression_type",
758        compression_type,
759        argument_default="",
760        argument_dtype=dtypes.string)
761    record_defaults = [
762        constant_op.constant([], dtype=x)
763        if not tensor_util.is_tf_type(x) and x in _ACCEPTABLE_CSV_TYPES else x
764        for x in record_defaults
765    ]
766    self._record_defaults = ops.convert_n_to_tensor(
767        record_defaults, name="record_defaults")
768    self._buffer_size = convert.optional_param_to_tensor(
769        "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
770    self._header = ops.convert_to_tensor(
771        header, dtype=dtypes.bool, name="header")
772    self._field_delim = ops.convert_to_tensor(
773        field_delim, dtype=dtypes.string, name="field_delim")
774    self._use_quote_delim = ops.convert_to_tensor(
775        use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
776    self._na_value = ops.convert_to_tensor(
777        na_value, dtype=dtypes.string, name="na_value")
778    self._select_cols = convert.optional_param_to_tensor(
779        "select_cols",
780        select_cols,
781        argument_default=[],
782        argument_dtype=dtypes.int64,
783    )
784    self._exclude_cols = convert.optional_param_to_tensor(
785        "exclude_cols",
786        exclude_cols,
787        argument_default=[],
788        argument_dtype=dtypes.int64,
789    )
790    self._element_spec = tuple(
791        tensor_spec.TensorSpec([], d.dtype) for d in self._record_defaults)
792    if compat.forward_compatible(2020, 7, 3) or exclude_cols is not None:
793      variant_tensor = gen_experimental_dataset_ops.csv_dataset_v2(
794          filenames=self._filenames,
795          record_defaults=self._record_defaults,
796          buffer_size=self._buffer_size,
797          header=self._header,
798          output_shapes=self._flat_shapes,
799          field_delim=self._field_delim,
800          use_quote_delim=self._use_quote_delim,
801          na_value=self._na_value,
802          select_cols=self._select_cols,
803          exclude_cols=self._exclude_cols,
804          compression_type=self._compression_type)
805    else:
806      variant_tensor = gen_experimental_dataset_ops.csv_dataset(
807          filenames=self._filenames,
808          record_defaults=self._record_defaults,
809          buffer_size=self._buffer_size,
810          header=self._header,
811          output_shapes=self._flat_shapes,
812          field_delim=self._field_delim,
813          use_quote_delim=self._use_quote_delim,
814          na_value=self._na_value,
815          select_cols=self._select_cols,
816          compression_type=self._compression_type)
817    super(CsvDatasetV2, self).__init__(variant_tensor)
818
819  @property
820  def element_spec(self):
821    return self._element_spec
822
823
@tf_export(v1=["data.experimental.CsvDataset"])
class CsvDatasetV1(dataset_ops.DatasetV1Adapter):
  """A Dataset comprising lines from one or more CSV files."""

  @functools.wraps(CsvDatasetV2.__init__, ("__module__", "__name__"))
  def __init__(self,
               filenames,
               record_defaults,
               compression_type=None,
               buffer_size=None,
               header=False,
               field_delim=",",
               use_quote_delim=True,
               na_value="",
               select_cols=None,
               exclude_cols=None):
    """Creates a `CsvDataset` by reading and decoding CSV files.

    The elements of this dataset correspond to records from the file(s).
    RFC 4180 format is expected for CSV files
    (https://tools.ietf.org/html/rfc4180).
    Note that we allow leading and trailing spaces with int or float fields.

    For example, suppose we have a file 'my_file0.csv' with four CSV columns of
    different data types:
    ```
    abcdefg,4.28E10,5.55E6,12
    hijklmn,-5.3E14,,2
    ```

    We can construct a CsvDataset from it as follows:

    ```python
    dataset = tf.data.experimental.CsvDataset(
        "my_file*.csv",
        [tf.float32,  # Required field, use dtype or empty tensor
         tf.constant([0.0], dtype=tf.float32),  # Optional field, default to 0.0
         tf.int32,  # Required field, use dtype or empty tensor
        ],
        select_cols=[1,2,3]  # Only parse last three columns
    )
    ```

    The expected output of its iterations is:

    ```python
    for element in dataset:
      print(element)

    >> (4.28e10, 5.55e6, 12)
    >> (-5.3e14, 0.0, 2)
    ```

    Alternatively, suppose we have a CSV file of floats with 200 columns,
    and we want to use all columns besides the first. We can construct a
    CsvDataset from it as follows:

    ```python
    dataset = tf.data.experimental.CsvDataset(
        "my_file.csv",
        [tf.float32] * 199,  # Parse 199 required columns as floats
        exclude_cols=[0]  # Parse all columns except the first
    )
    ```

    This TF1 endpoint builds a `CsvDatasetV2` from the same arguments and
    wraps it in a `DatasetV1Adapter`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_defaults: A list of default values for the CSV fields. Each item in
        the list is either a valid CSV `DType` (float32, float64, int32, int64,
        string), or a `Tensor` object with one of the above types. One per
        column of CSV data, with either a scalar `Tensor` default value for the
        column if it is optional, or `DType` or empty `Tensor` if required. If
        both this and `select_cols` are specified, these must have the same
        lengths, and `record_defaults` is assumed to be sorted in order of
        increasing column index. If both this and `exclude_cols` are specified,
        the sum of lengths of `record_defaults` and `exclude_cols` should equal
        the total number of columns in the CSV file.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
        compression.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer while reading files. Defaults to 4MB.
      header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
        have header line(s) that should be skipped when parsing. Defaults to
        `False`.
      field_delim: (Optional.) A `tf.string` scalar containing the delimiter
        character that separates fields in a record. Defaults to `","`.
      use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats double
        quotation marks as regular characters inside of string fields (ignoring
        RFC 4180, Section 2, Bullet 5). Defaults to `True`.
      na_value: (Optional.) A `tf.string` scalar indicating a value that will be
        treated as NA/NaN.
      select_cols: (Optional.) A sorted list of column indices to select from
        the input data. If specified, only this subset of columns will be
        parsed. Defaults to parsing all columns. At most one of `select_cols`
        and `exclude_cols` can be specified.
      exclude_cols: (Optional.) A sorted list of column indices to exclude from
        the input data. If specified, only the complement of this set of
        columns will be parsed. Defaults to parsing all columns. At most one of
        `select_cols` and `exclude_cols` can be specified.

    Raises:
       InvalidArgumentError: If exclude_cols is not None and
           len(exclude_cols) + len(record_defaults) does not match the total
           number of columns in the file(s)
    """
    # `exclude_cols` is forwarded by keyword so it cannot be mistaken for a
    # positional parameter of `CsvDatasetV2`; the other arguments keep their
    # original positional order.
    wrapped = CsvDatasetV2(filenames, record_defaults, compression_type,
                           buffer_size, header, field_delim, use_quote_delim,
                           na_value, select_cols, exclude_cols=exclude_cols)
    super(CsvDatasetV1, self).__init__(wrapped)
913
914
@tf_export("data.experimental.make_batched_features_dataset", v1=[])
def make_batched_features_dataset_v2(file_pattern,
                                     batch_size,
                                     features,
                                     reader=None,
                                     label_key=None,
                                     reader_args=None,
                                     num_epochs=None,
                                     shuffle=True,
                                     shuffle_buffer_size=10000,
                                     shuffle_seed=None,
                                     prefetch_buffer_size=None,
                                     reader_num_threads=None,
                                     parser_num_threads=None,
                                     sloppy_ordering=False,
                                     drop_final_batch=False):
  """Returns a `Dataset` of feature dictionaries from `Example` protos.

  If label_key argument is provided, returns a `Dataset` of tuple
  comprising of feature dictionaries and label.

  Example:

  ```
  serialized_examples = [
    features {
      feature { key: "age" value { int64_list { value: [ 0 ] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
      feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } }
    },
    features {
      feature { key: "age" value { int64_list { value: [] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
      feature { key: "kws" value { bytes_list { value: [ "sports" ] } } }
    }
  ]
  ```

  We can use arguments:

  ```
  features: {
    "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
    "gender": FixedLenFeature([], dtype=tf.string),
    "kws": VarLenFeature(dtype=tf.string),
  }
  ```

  And the expected output is:

  ```python
  {
    "age": [0, -1],
    "gender": ["f", "f"],
    "kws": SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=["code", "art", "sports"],
      dense_shape=[2, 2]),
  }
  ```

  Args:
    file_pattern: List of files or patterns of file paths containing
      `Example` records. See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    features: A `dict` mapping feature keys to `FixedLenFeature` or
      `VarLenFeature` values. See `tf.io.parse_example`.
    reader: A function or class that can be
      called with a `filenames` tensor and (optional) `reader_args` and returns
      a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
    label_key: (Optional) A string corresponding to the key labels are stored in
      `tf.Examples`. If provided, it must be one of the `features` key,
      otherwise results in `ValueError`.
    reader_args: Additional arguments to pass to the reader class.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If None, cycles through the dataset forever. Defaults to `None`.
    shuffle: A boolean, indicates whether the input should be shuffled. Defaults
      to `True`.
    shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity
      ensures better shuffling but would increase memory usage and startup time.
    shuffle_seed: Randomization seed to use for shuffling.
    prefetch_buffer_size: Number of feature batches to prefetch in order to
      improve performance. Recommended value is the number of batches consumed
      per training step. Defaults to auto-tune.
    reader_num_threads: Number of threads used to read `Example` records. If >1,
      the results will be interleaved. May also be
      `tf.data.experimental.AUTOTUNE`, in which case `Dataset.interleave` is
      used with an auto-tuned level of parallelism. Defaults to `1`.
    parser_num_threads: Number of threads to use for parsing `Example` tensors
      into a dictionary of `Feature` tensors. Defaults to `2`.
    sloppy_ordering: If `True`, reading performance will be improved at
      the cost of non-deterministic ordering. If `False`, the order of elements
      produced is deterministic prior to shuffling (elements are still
      randomized if `shuffle=True`. Note that if the seed is set, then order
      of elements after shuffling is deterministic). Defaults to `False`.
    drop_final_batch: If `True`, and the batch size does not evenly divide the
      input dataset size, the final smaller batch will be dropped. Defaults to
      `False`.

  Returns:
    A dataset of `dict` elements, (or a tuple of `dict` elements and label).
    Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.

  Raises:
    TypeError: If `reader` is of the wrong type.
    ValueError: If `label_key` is not one of the `features` keys.
  """
  if reader is None:
    reader = core_readers.TFRecordDataset

  # Validate the pure-Python arguments before any dataset ops are built, so
  # that misuse is reported immediately rather than after file listing,
  # interleaving, batching and parsing have been set up.
  if isinstance(reader, type) and issubclass(reader, io_ops.ReaderBase):
    raise TypeError("The `reader` argument must return a `Dataset` object. "
                    "`tf.ReaderBase` subclasses are not supported. For "
                    "example, pass `tf.data.TFRecordDataset` instead of "
                    "`tf.TFRecordReader`.")
  if label_key and label_key not in features:
    # Wrap in a 1-tuple so `%` formatting works for any key type.
    raise ValueError(
        "The `label_key` provided (%r) must be one of the `features` keys." %
        (label_key,))

  if reader_num_threads is None:
    reader_num_threads = 1
  if parser_num_threads is None:
    parser_num_threads = 2
  if prefetch_buffer_size is None:
    prefetch_buffer_size = dataset_ops.AUTOTUNE
  if reader_args is None:
    reader_args = []

  # Create dataset of all matching filenames
  dataset = dataset_ops.Dataset.list_files(
      file_pattern, shuffle=shuffle, seed=shuffle_seed)

  # Read `Example` records from files as tensor objects.
  if reader_num_threads == dataset_ops.AUTOTUNE:
    dataset = dataset.interleave(
        lambda filename: reader(filename, *reader_args),
        num_parallel_calls=reader_num_threads)
    # With AUTOTUNE, determinism is controlled via dataset options rather
    # than the `sloppy` flag of the parallel interleave below.
    options = dataset_ops.Options()
    options.experimental_deterministic = not sloppy_ordering
    dataset = dataset.with_options(options)
  else:
    # Read files sequentially (if reader_num_threads=1) or in parallel
    def apply_fn(dataset):
      return core_readers.ParallelInterleaveDataset(
          dataset,
          lambda filename: reader(filename, *reader_args),
          cycle_length=reader_num_threads,
          block_length=1,
          sloppy=sloppy_ordering,
          buffer_output_elements=None,
          prefetch_input_elements=None)

    dataset = dataset.apply(apply_fn)

  # Extract values if the `Example` tensors are stored as key-value tuples.
  if dataset_ops.get_legacy_output_types(dataset) == (
      dtypes.string, dtypes.string):
    dataset = dataset_ops.MapDataset(
        dataset, lambda _, v: v, use_inter_op_parallelism=False)

  # Apply dataset repeat and shuffle transformations.
  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  dataset = dataset.batch(
      batch_size, drop_remainder=drop_final_batch or num_epochs is None)

  # Parse `Example` tensors to a dictionary of `Feature` tensors.
  dataset = dataset.apply(
      parsing_ops.parse_example_dataset(
          features, num_parallel_calls=parser_num_threads))

  if label_key:
    # The tuple's first item references the same dict that `pop` mutates, so
    # the returned features no longer contain the label entry.
    dataset = dataset.map(lambda x: (x, x.pop(label_key)))

  dataset = dataset.prefetch(prefetch_buffer_size)
  return dataset
1097
1098
@tf_export(v1=["data.experimental.make_batched_features_dataset"])
def make_batched_features_dataset_v1(file_pattern,  # pylint: disable=missing-docstring
                                     batch_size,
                                     features,
                                     reader=None,
                                     label_key=None,
                                     reader_args=None,
                                     num_epochs=None,
                                     shuffle=True,
                                     shuffle_buffer_size=10000,
                                     shuffle_seed=None,
                                     prefetch_buffer_size=None,
                                     reader_num_threads=None,
                                     parser_num_threads=None,
                                     sloppy_ordering=False,
                                     drop_final_batch=False):
  # Delegate every argument to the v2 implementation, then wrap the result
  # in a `DatasetV1Adapter` for the TF1 API.
  v2_dataset = make_batched_features_dataset_v2(
      file_pattern=file_pattern,
      batch_size=batch_size,
      features=features,
      reader=reader,
      label_key=label_key,
      reader_args=reader_args,
      num_epochs=num_epochs,
      shuffle=shuffle,
      shuffle_buffer_size=shuffle_buffer_size,
      shuffle_seed=shuffle_seed,
      prefetch_buffer_size=prefetch_buffer_size,
      reader_num_threads=reader_num_threads,
      parser_num_threads=parser_num_threads,
      sloppy_ordering=sloppy_ordering,
      drop_final_batch=drop_final_batch)
  return dataset_ops.DatasetV1Adapter(v2_dataset)
# The v1 endpoint reuses the v2 documentation verbatim.
make_batched_features_dataset_v1.__doc__ = (
    make_batched_features_dataset_v2.__doc__)
1122
1123
def _get_file_names(file_pattern, shuffle):
  """Expands one glob pattern, or a list of them, into concrete file names.

  Args:
    file_pattern: A single glob pattern, or a `list` of glob patterns whose
      matches are concatenated in list order.
    shuffle: Controls ordering only. When `False` the names are sorted so the
      result is deterministic; when `True` they are returned in the order
      produced by globbing. No shuffling is performed here.

  Returns:
    A list of file names matching `file_pattern`.

  Raises:
    ValueError: If `file_pattern` is an empty list, or no file matches it.
  """
  if not isinstance(file_pattern, list):
    file_names = list(gfile.Glob(file_pattern))
  elif file_pattern:
    file_names = [
        matched for pattern in file_pattern for matched in gfile.Glob(pattern)
    ]
  else:
    raise ValueError("File pattern is empty.")

  if not file_names:
    raise ValueError("No files match %s." % file_pattern)

  # Sort files so it will be deterministic for unit tests.
  return file_names if shuffle else sorted(file_names)
1153
1154
@tf_export("data.experimental.SqlDataset", v1=[])
class SqlDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` consisting of the results from a SQL query."""

  def __init__(self, driver_name, data_source_name, query, output_types):
    """Creates a `SqlDataset`.

    `SqlDataset` allows a user to read data from the result set of a SQL query.
    For example:

    ```python
    dataset = tf.data.experimental.SqlDataset("sqlite", "/foo/bar.sqlite3",
                                              "SELECT name, age FROM people",
                                              (tf.string, tf.int32))
    # Prints the rows of the result set of the above query.
    for element in dataset:
      print(element)
    ```

    Args:
      driver_name: A 0-D `tf.string` tensor containing the database type.
        Currently, the only supported value is 'sqlite'.
      data_source_name: A 0-D `tf.string` tensor containing a connection string
        to connect to the database.
      query: A 0-D `tf.string` tensor containing the SQL query to execute.
      output_types: A tuple of `tf.DType` objects representing the types of the
        columns returned by `query`.
    """

    def _to_string_tensor(value, tensor_name):
      # All three connection/query arguments are converted the same way.
      return ops.convert_to_tensor(value, dtype=dtypes.string, name=tensor_name)

    self._driver_name = _to_string_tensor(driver_name, "driver_name")
    self._data_source_name = _to_string_tensor(data_source_name,
                                               "data_source_name")
    self._query = _to_string_tensor(query, "query")

    # Each result column is described by a scalar spec of its requested dtype,
    # mirroring the (possibly nested) structure of `output_types`.
    def _scalar_spec(column_dtype):
      return tensor_spec.TensorSpec([], column_dtype)

    self._element_spec = nest.map_structure(_scalar_spec, output_types)

    # NOTE(review): `_flat_structure` is read only after `_element_spec` is
    # assigned; presumably it is derived from `element_spec` — keep this order.
    flat_structure = self._flat_structure
    variant_tensor = gen_experimental_dataset_ops.sql_dataset(
        self._driver_name, self._data_source_name, self._query,
        **flat_structure)
    super(SqlDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    """The structure of scalar `tf.TensorSpec`s built from `output_types`."""
    return self._element_spec
1199
1200
@tf_export(v1=["data.experimental.SqlDataset"])
class SqlDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` consisting of the results from a SQL query."""

  @functools.wraps(SqlDatasetV2.__init__)
  def __init__(self, driver_name, data_source_name, query, output_types):
    # `functools.wraps` copies the docstring from `SqlDatasetV2.__init__`.
    # Build the v2 dataset and expose it through the TF1 adapter.
    v2_dataset = SqlDatasetV2(
        driver_name=driver_name,
        data_source_name=data_source_name,
        query=query,
        output_types=output_types)
    super(SqlDatasetV1, self).__init__(v2_dataset)
1209
1210
# Public aliases point at the TF2 or TF1 implementations, depending on whether
# TF2 behavior is enabled when this module is imported.
if tf2.enabled():
  (CsvDataset, SqlDataset, make_batched_features_dataset,
   make_csv_dataset) = (CsvDatasetV2, SqlDatasetV2,
                        make_batched_features_dataset_v2,
                        make_csv_dataset_v2)
else:
  (CsvDataset, SqlDataset, make_batched_features_dataset,
   make_csv_dataset) = (CsvDatasetV1, SqlDatasetV1,
                        make_batched_features_dataset_v1,
                        make_csv_dataset_v1)
1221