# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for tf.data writers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.TFRecordWriter")
class TFRecordWriter(object):
  """Writes a dataset to a TFRecord file.

  The elements of the dataset must be scalar strings. To serialize dataset
  elements as strings, you can use the `tf.io.serialize_tensor` function.

  ```python
  dataset = tf.data.Dataset.range(3)
  dataset = dataset.map(tf.io.serialize_tensor)
  writer = tf.data.experimental.TFRecordWriter("/path/to/file.tfrecord")
  writer.write(dataset)
  ```

  To read back the elements, use `TFRecordDataset`.

  ```python
  dataset = tf.data.TFRecordDataset("/path/to/file.tfrecord")
  dataset = dataset.map(lambda x: tf.io.parse_tensor(x, tf.int64))
  ```

  To shard a `dataset` across multiple TFRecord files:

  ```python
  dataset = ...  # dataset to be written

  def reduce_func(key, dataset):
    filename = tf.strings.join([PATH_PREFIX, tf.strings.as_string(key)])
    writer = tf.data.experimental.TFRecordWriter(filename)
    writer.write(dataset.map(lambda _, x: x))
    return tf.data.Dataset.from_tensors(filename)

  dataset = dataset.enumerate()
  dataset = dataset.apply(tf.data.experimental.group_by_window(
      lambda i, _: i % NUM_SHARDS, reduce_func, tf.int64.max
  ))
  ```
  """

  def __init__(self, filename, compression_type=None):
    """Initializes a `TFRecordWriter`.

    Args:
      filename: a string path indicating where to write the TFRecord data.
      compression_type: (Optional.) a string indicating what type of compression
        to use when writing the file. See `tf.io.TFRecordCompressionType` for
        what types of compression are available. Defaults to `None`.
    """
    # Eagerly convert both arguments to tensors so that invalid values fail
    # here, at construction time, rather than inside `write()`.
    self._filename = ops.convert_to_tensor(
        filename, dtypes.string, name="filename")
    # `optional_param_to_tensor` maps `None` to the empty string, which the
    # underlying op interprets as "no compression".
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)

  def write(self, dataset):
    """Writes a dataset to a TFRecord file.

    An operation that writes the content of the specified dataset to the file
    specified in the constructor.

    If the file exists, it will be overwritten.

    Args:
      dataset: a `tf.data.Dataset` whose elements are to be written to a file

    Returns:
      In graph mode, this returns an operation which when executed performs the
      write. In eager mode, the write is performed by the method itself and
      there is no return value.

    Raises:
      TypeError: if `dataset` is not a `tf.data.Dataset`.
      TypeError: if the elements produced by the dataset are not scalar strings.
    """
    if not isinstance(dataset, dataset_ops.DatasetV2):
      raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
    # The underlying op serializes each element as a TFRecord, so every
    # element must be a single (scalar) DT_STRING tensor.
    if not dataset_ops.get_structure(dataset).is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.string)):
      raise TypeError(
          "`dataset` must produce scalar `DT_STRING` tensors whereas it "
          "produces shape {0} and types {1}".format(
              dataset_ops.get_legacy_output_shapes(dataset),
              dataset_ops.get_legacy_output_types(dataset)))
    return gen_experimental_dataset_ops.dataset_to_tf_record(
        dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access