1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python wrappers for tf.data writers."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.data.util import convert
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_spec
25from tensorflow.python.ops import gen_experimental_dataset_ops
26from tensorflow.python.util.tf_export import tf_export
27
28
29@tf_export("data.experimental.TFRecordWriter")
30class TFRecordWriter(object):
31  """Writes a dataset to a TFRecord file.
32
33  The elements of the dataset must be scalar strings. To serialize dataset
34  elements as strings, you can use the `tf.io.serialize_tensor` function.
35
36  ```python
37  dataset = tf.data.Dataset.range(3)
38  dataset = dataset.map(tf.io.serialize_tensor)
39  writer = tf.data.experimental.TFRecordWriter("/path/to/file.tfrecord")
40  writer.write(dataset)
41  ```
42
43  To read back the elements, use `TFRecordDataset`.
44
45  ```python
46  dataset = tf.data.TFRecordDataset("/path/to/file.tfrecord")
47  dataset = dataset.map(lambda x: tf.io.parse_tensor(x, tf.int64))
48  ```
49
50  To shard a `dataset` across multiple TFRecord files:
51
52  ```python
53  dataset = ... # dataset to be written
54
55  def reduce_func(key, dataset):
56    filename = tf.strings.join([PATH_PREFIX, tf.strings.as_string(key)])
57    writer = tf.data.experimental.TFRecordWriter(filename)
58    writer.write(dataset.map(lambda _, x: x))
59    return tf.data.Dataset.from_tensors(filename)
60
61  dataset = dataset.enumerate()
62  dataset = dataset.apply(tf.data.experimental.group_by_window(
63    lambda i, _: i % NUM_SHARDS, reduce_func, tf.int64.max
64  ))
65  ```
66  """
67
68  def __init__(self, filename, compression_type=None):
69    """Initializes a `TFRecordWriter`.
70
71    Args:
72      filename: a string path indicating where to write the TFRecord data.
73      compression_type: (Optional.) a string indicating what type of compression
74        to use when writing the file. See `tf.io.TFRecordCompressionType` for
75        what types of compression are available. Defaults to `None`.
76    """
77    self._filename = ops.convert_to_tensor(
78        filename, dtypes.string, name="filename")
79    self._compression_type = convert.optional_param_to_tensor(
80        "compression_type",
81        compression_type,
82        argument_default="",
83        argument_dtype=dtypes.string)
84
85  def write(self, dataset):
86    """Writes a dataset to a TFRecord file.
87
88    An operation that writes the content of the specified dataset to the file
89    specified in the constructor.
90
91    If the file exists, it will be overwritten.
92
93    Args:
94      dataset: a `tf.data.Dataset` whose elements are to be written to a file
95
96    Returns:
97      In graph mode, this returns an operation which when executed performs the
98      write. In eager mode, the write is performed by the method itself and
99      there is no return value.
100
101    Raises
102      TypeError: if `dataset` is not a `tf.data.Dataset`.
103      TypeError: if the elements produced by the dataset are not scalar strings.
104    """
105    if not isinstance(dataset, dataset_ops.DatasetV2):
106      raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
107    if not dataset_ops.get_structure(dataset).is_compatible_with(
108        tensor_spec.TensorSpec([], dtypes.string)):
109      raise TypeError(
110          "`dataset` must produce scalar `DT_STRING` tensors whereas it "
111          "produces shape {0} and types {1}".format(
112              dataset_ops.get_legacy_output_shapes(dataset),
113              dataset_ops.get_legacy_output_types(dataset)))
114    return gen_experimental_dataset_ops.dataset_to_tf_record(
115        dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
116