# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental `dataset` API for parsing example."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util.tf_export import tf_export


class _ParseExampleDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that parses `example` dataset into a `dict` dataset."""

  def __init__(self, input_dataset, features, num_parallel_calls,
               deterministic):
    """Wraps `input_dataset` with a `ParseExampleDatasetV2` op.

    `input_dataset` must yield rank-1 string tensors (batches of serialized
    `Example` protos); each element is parsed into a dict keyed by feature
    name.
    """
    self._input_dataset = input_dataset
    # The op only accepts vectors of serialized protos; reject anything else
    # up front with a clear error rather than failing inside the kernel.
    if not structure.are_compatible(
        input_dataset.element_spec,
        tensor_spec.TensorSpec([None], dtypes.string)):
      raise TypeError("Input dataset should be a dataset of vectors of strings")
    self._num_parallel_calls = num_parallel_calls
    # The op expects determinism as one of the strings "default"/"true"/"false".
    if deterministic is None:
      self._deterministic = "default"
    else:
      self._deterministic = "true" if deterministic else "false"
    # pylint: disable=protected-access
    self._features = parsing_ops._prepend_none_dimension(features)
    # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature
    params = parsing_ops._ParseOpParams.from_features(self._features, [
        parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
        parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature,
        parsing_ops.RaggedFeature
    ])
    # pylint: enable=protected-access
    self._sparse_keys = params.sparse_keys
    self._sparse_types = params.sparse_types
    self._ragged_keys = params.ragged_keys
    self._ragged_value_types = params.ragged_value_types
    self._ragged_split_types = params.ragged_split_types
    self._dense_keys = params.dense_keys
    self._dense_defaults = params.dense_defaults_vec
    self._dense_shapes = params.dense_shapes_as_proto
    self._dense_types = params.dense_types
    input_dataset_shape = dataset_ops.get_legacy_output_shapes(
        self._input_dataset)

    # Build the per-key output spec: the batch dimension of the input is
    # prepended to every feature's shape.
    specs = {}
    for key, dtype in zip(params.sparse_keys, params.sparse_types):
      specs[key] = sparse_tensor.SparseTensorSpec(
          input_dataset_shape.concatenate([None]), dtype)
    for key, dtype, dense_shape in zip(params.dense_keys, params.dense_types,
                                       params.dense_shapes):
      specs[key] = tensor_spec.TensorSpec(
          input_dataset_shape.concatenate(dense_shape), dtype)
    for key, dtype, row_splits_dtype in zip(params.ragged_keys,
                                            params.ragged_value_types,
                                            params.ragged_split_types):
      specs[key] = ragged_tensor.RaggedTensorSpec(
          input_dataset_shape.concatenate([None]), dtype, 1, row_splits_dtype)
    self._element_spec = specs

    variant_tensor = (
        gen_experimental_dataset_ops.parse_example_dataset_v2(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._num_parallel_calls,
            self._dense_defaults,
            self._sparse_keys,
            self._dense_keys,
            self._sparse_types,
            self._dense_shapes,
            deterministic=self._deterministic,
            ragged_keys=self._ragged_keys,
            ragged_value_types=self._ragged_value_types,
            ragged_split_types=self._ragged_split_types,
            **self._flat_structure))
    super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


# TODO(b/111553342): add arguments names and example names as well.
@tf_export("data.experimental.parse_example_dataset")
def parse_example_dataset(features, num_parallel_calls=1, deterministic=None):
  """A transformation that parses `Example` protos into a `dict` of tensors.

  Parses a number of serialized `Example` protos given in `serialized`. We
  refer to `serialized` as a batch with `batch_size` many entries of individual
  `Example` protos.

  This op parses serialized examples into a dictionary mapping keys to
  `Tensor`, `SparseTensor`, and `RaggedTensor` objects. `features` is a dict
  from keys to `VarLenFeature`, `RaggedFeature`, `SparseFeature`, and
  `FixedLenFeature` objects. Each `VarLenFeature` and `SparseFeature` is
  mapped to a `SparseTensor`; each `RaggedFeature` is mapped to a
  `RaggedTensor`; and each `FixedLenFeature` is mapped to a `Tensor`. See
  `tf.io.parse_example` for more details about feature dictionaries.

  Args:
    features: A `dict` mapping feature keys to `FixedLenFeature`,
      `VarLenFeature`, `RaggedFeature`, and `SparseFeature` values.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of parsing processes to call in parallel.
    deterministic: (Optional.) A boolean controlling whether determinism
      should be traded for performance by allowing elements to be produced out
      of order if some parsing calls complete faster than others. If
      `deterministic` is `None`, the
      `tf.data.Options.experimental_deterministic` dataset option (`True` by
      default) is used to decide whether to produce elements
      deterministically.

  Returns:
    A dataset transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: if features argument is None.
  """
  if features is None:
    raise ValueError("Missing: features was %s." % features)

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    parsed = _ParseExampleDataset(dataset, features, num_parallel_calls,
                                  deterministic)

    def _needs_reconstruction(feature):
      # SparseFeatures and partitioned RaggedFeatures come out of the op as
      # raw components and must be reassembled with an extra map step.
      return (isinstance(feature, parsing_ops.SparseFeature) or
              (isinstance(feature, parsing_ops.RaggedFeature) and
               feature.partitions))

    if any(_needs_reconstruction(f) for f in features.values()):
      # pylint: disable=protected-access
      # pylint: disable=g-long-lambda
      parsed = parsed.map(
          lambda x: parsing_ops._construct_tensors_for_composite_features(
              features, x),
          num_parallel_calls=num_parallel_calls)
    return parsed

  return _apply_fn