# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for reader Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util.tf_export import tf_export


# TODO(b/64974358): Increase default buffer size to 256 MB.
_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024  # 256 KB


@tf_export("data.TextLineDataset", v1=[])
class TextLineDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` comprising lines from one or more text files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None):
    """Creates a `TextLineDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer. A value of 0 results in the default buffering values chosen
        based on the compression type.
    """
    self._filenames = ops.convert_to_tensor(
        filenames, dtype=dtypes.string, name="filenames")
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
    variant_tensor = gen_dataset_ops.text_line_dataset(
        self._filenames, self._compression_type, self._buffer_size)
    super(TextLineDatasetV2, self).__init__(variant_tensor)

  @property
  def _element_structure(self):
    return structure.TensorStructure(dtypes.string, [])


@tf_export(v1=["data.TextLineDataset"])
class TextLineDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` comprising lines from one or more text files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None):
    wrapped = TextLineDatasetV2(filenames, compression_type, buffer_size)
    super(TextLineDatasetV1, self).__init__(wrapped)
  __init__.__doc__ = TextLineDatasetV2.__init__.__doc__

  @property
  def _filenames(self):
    return self._dataset._filenames  # pylint: disable=protected-access

  @_filenames.setter
  def _filenames(self, value):
    self._dataset._filenames = value  # pylint: disable=protected-access


class _TFRecordDataset(dataset_ops.DatasetSource):
  """A `Dataset` comprising records from one or more TFRecord files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None):
    """Creates a `TFRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. 0 means no buffering.
    """
    # Force the type to string even if filenames is an empty list.
    self._filenames = ops.convert_to_tensor(
        filenames, dtypes.string, name="filenames")
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size",
        buffer_size,
        argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)
    variant_tensor = gen_dataset_ops.tf_record_dataset(
        self._filenames, self._compression_type, self._buffer_size)
    super(_TFRecordDataset, self).__init__(variant_tensor)

  @property
  def _element_structure(self):
    return structure.TensorStructure(dtypes.string, [])


class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over its input and flattens the result."""

  def __init__(self, input_dataset, map_func, cycle_length, block_length,
               sloppy, buffer_output_elements, prefetch_input_elements):
    """See `tf.data.experimental.parallel_interleave()` for details."""
    self._input_dataset = input_dataset
    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure,
                      dataset_ops.DatasetStructure):
      raise TypeError("`map_func` must return a `Dataset` object.")
    self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._sloppy = ops.convert_to_tensor(
        sloppy, dtype=dtypes.bool, name="sloppy")
    self._buffer_output_elements = convert.optional_param_to_tensor(
        "buffer_output_elements",
        buffer_output_elements,
        argument_default=2 * block_length)
    self._prefetch_input_elements = convert.optional_param_to_tensor(
        "prefetch_input_elements",
        prefetch_input_elements,
        argument_default=2 * cycle_length)
    variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        self._cycle_length,
        self._block_length,
        self._sloppy,
        self._buffer_output_elements,
        self._prefetch_input_elements,
        f=self._map_func.function,
        **dataset_ops.flat_structure(self))
    super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                    variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def _element_structure(self):
    return self._structure

  def _transformation_name(self):
    return "tf.data.experimental.parallel_interleave()"


@tf_export("data.TFRecordDataset", v1=[])
class TFRecordDatasetV2(dataset_ops.DatasetV2):
  """A `Dataset` comprising records from one or more TFRecord files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None,
               num_parallel_reads=None):
    """Creates a `TFRecordDataset` to read one or more TFRecord files.

    NOTE: The `num_parallel_reads` argument can be used to improve performance
    when reading from a remote filesystem.

    Args:
      filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or
        more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. 0 means no buffering.
      num_parallel_reads: (Optional.) A `tf.int64` scalar representing the
        number of files to read in parallel. Defaults to reading files
        sequentially.

    Raises:
      TypeError: If any argument does not have the expected type.
      ValueError: If any argument does not have the expected shape.
    """
    if isinstance(filenames, dataset_ops.DatasetV2):
      if dataset_ops.get_legacy_output_types(filenames) != dtypes.string:
        raise TypeError(
            "`filenames` must be a `tf.data.Dataset` of `tf.string` elements.")
      if not dataset_ops.get_legacy_output_shapes(filenames).is_compatible_with(
          tensor_shape.scalar()):
        raise ValueError(
            "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
            "elements.")
    else:
      filenames = ops.convert_to_tensor(filenames, dtype=dtypes.string)
      filenames = array_ops.reshape(filenames, [-1], name="flat_filenames")
      filenames = dataset_ops.DatasetV2.from_tensor_slices(filenames)

    self._filenames = filenames
    self._compression_type = compression_type
    self._buffer_size = buffer_size
    self._num_parallel_reads = num_parallel_reads

    def read_one_file(filename):
      return _TFRecordDataset(filename, compression_type, buffer_size)

    if num_parallel_reads is None:
      self._impl = filenames.flat_map(read_one_file)
    else:
      self._impl = ParallelInterleaveDataset(
          filenames, read_one_file, cycle_length=num_parallel_reads,
          block_length=1, sloppy=False, buffer_output_elements=None,
          prefetch_input_elements=None)
    variant_tensor = self._impl._variant_tensor  # pylint: disable=protected-access
    super(TFRecordDatasetV2, self).__init__(variant_tensor)

  def _clone(self,
             filenames=None,
             compression_type=None,
             buffer_size=None,
             num_parallel_reads=None):
    return TFRecordDatasetV2(filenames or self._filenames,
                             compression_type or self._compression_type,
                             buffer_size or self._buffer_size,
                             num_parallel_reads or self._num_parallel_reads)

  def _inputs(self):
    return self._impl._inputs()  # pylint: disable=protected-access

  @property
  def _element_structure(self):
    return structure.TensorStructure(dtypes.string, [])


@tf_export(v1=["data.TFRecordDataset"])
class TFRecordDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` comprising records from one or more TFRecord files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None,
               num_parallel_reads=None):
    wrapped = TFRecordDatasetV2(
        filenames, compression_type, buffer_size, num_parallel_reads)
    super(TFRecordDatasetV1, self).__init__(wrapped)
  __init__.__doc__ = TFRecordDatasetV2.__init__.__doc__

  def _clone(self,
             filenames=None,
             compression_type=None,
             buffer_size=None,
             num_parallel_reads=None):
    # pylint: disable=protected-access
    return TFRecordDatasetV1(
        filenames or self._dataset._filenames,
        compression_type or self._dataset._compression_type,
        buffer_size or self._dataset._buffer_size,
        num_parallel_reads or self._dataset._num_parallel_reads)

  @property
  def _filenames(self):
    return self._dataset._filenames  # pylint: disable=protected-access

  @_filenames.setter
  def _filenames(self, value):
    self._dataset._filenames = value  # pylint: disable=protected-access


@tf_export("data.FixedLengthRecordDataset", v1=[])
class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` of fixed-length records from one or more binary files."""

  def __init__(self,
               filenames,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               buffer_size=None,
               compression_type=None):
    """Creates a `FixedLengthRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_bytes: A `tf.int64` scalar representing the number of bytes in
        each record.
      header_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to skip at the start of a file.
      footer_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to ignore at the end of a file.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes to buffer when reading.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
    """
    self._filenames = ops.convert_to_tensor(
        filenames, dtype=dtypes.string, name="filenames")
    self._record_bytes = ops.convert_to_tensor(
        record_bytes, dtype=dtypes.int64, name="record_bytes")

    self._header_bytes = convert.optional_param_to_tensor(
        "header_bytes", header_bytes)
    self._footer_bytes = convert.optional_param_to_tensor(
        "footer_bytes", footer_bytes)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    if (self._compression_type is not None or
        compat.forward_compatible(2018, 11, 30)):
      variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2(
          self._filenames, self._header_bytes, self._record_bytes,
          self._footer_bytes, self._buffer_size, self._compression_type)
    else:
      variant_tensor = gen_dataset_ops.fixed_length_record_dataset(
          self._filenames, self._header_bytes, self._record_bytes,
          self._footer_bytes, self._buffer_size)
    super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor)

  @property
  def _element_structure(self):
    return structure.TensorStructure(dtypes.string, [])


@tf_export(v1=["data.FixedLengthRecordDataset"])
class FixedLengthRecordDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` of fixed-length records from one or more binary files."""

  def __init__(self,
               filenames,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               buffer_size=None,
               compression_type=None):
    wrapped = FixedLengthRecordDatasetV2(
        filenames, record_bytes, header_bytes, footer_bytes, buffer_size,
        compression_type)
    super(FixedLengthRecordDatasetV1, self).__init__(wrapped)
  __init__.__doc__ = FixedLengthRecordDatasetV2.__init__.__doc__

  @property
  def _filenames(self):
    return self._dataset._filenames  # pylint: disable=protected-access

  @_filenames.setter
  def _filenames(self, value):
    self._dataset._filenames = value  # pylint: disable=protected-access


# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
# these aliases in place.
FixedLengthRecordDataset = FixedLengthRecordDatasetV1
TFRecordDataset = TFRecordDatasetV1
TextLineDataset = TextLineDatasetV1
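

# A minimal usage sketch of the readers defined above, assuming a TF 1.x-style
# graph context; the file paths are hypothetical placeholders, and each reader
# yields scalar `tf.string` elements that a subsequent `map()` would parse.
if __name__ == "__main__":
  # Lines of (optionally compressed) text files, one element per line.
  lines = TextLineDataset(
      ["/tmp/logs_0.txt", "/tmp/logs_1.txt"], compression_type="GZIP")

  # TFRecord files, read two files at a time via `num_parallel_reads`.
  records = TFRecordDataset(
      ["/tmp/train_0.tfrecord", "/tmp/train_1.tfrecord"],
      buffer_size=8 * 1024 * 1024,
      num_parallel_reads=2)

  # Fixed-length binary records, e.g. a 1-byte label followed by a
  # 32 x 32 x 3 uint8 image.
  images = FixedLengthRecordDataset(
      ["/tmp/data.bin"], record_bytes=1 + 32 * 32 * 3)

  # All three readers produce datasets of scalar strings.
  for ds in (lines, records, images):
    print(ds.output_types, ds.output_shapes)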