1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Feature configuration for tf.io.parse_example."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import re
23
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import check_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import sparse_ops
32from tensorflow.python.ops.ragged import ragged_math_ops
33from tensorflow.python.ops.ragged import ragged_tensor
34from tensorflow.python.platform import tf_logging
35from tensorflow.python.util.tf_export import tf_export
36
37
38# TODO(b/122887740) Refactor code:
39#   * Move input verification to feature configuration objects (e.g.,
40#     VarLenFeature should check that dtype is a valid dtype).
41#   * Add an _add_feature() method to each feature configuration object
42#     (rather than using a dispatch table in _ParseOpParams._add_feature).
43#   * Update _construct_tensors_for_composite_features() to call a method
44#     on the feature object (rather than using dispatch).
45
46
@tf_export("io.VarLenFeature", v1=["VarLenFeature", "io.VarLenFeature"])
class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])):
  """Configuration for parsing a variable-length input feature.

  Features configured this way are registered as sparse keys and are
  therefore parsed into `SparseTensor`s.

  Fields:
    dtype: Data type of input.
  """
55
56
@tf_export("io.RaggedFeature")
class RaggedFeature(
    collections.namedtuple(
        "RaggedFeature",
        ["dtype", "value_key", "partitions", "row_splits_dtype", "validate"])):
  """Configuration for parsing a RaggedTensor input feature.

  `value_key` specifies the feature key for a variable-length list of values;
  and `partitions` specifies zero or more feature keys for partitioning those
  values into higher dimensions.  Each element of `partitions` must be one of
  the following:

    * `tf.io.RaggedFeature.RowSplits(key: string)`
    * `tf.io.RaggedFeature.RowLengths(key: string)`
    * `tf.io.RaggedFeature.RowStarts(key: string)`
    * `tf.io.RaggedFeature.RowLimits(key: string)`
    * `tf.io.RaggedFeature.ValueRowIds(key: string)`
    * `tf.io.RaggedFeature.UniformRowLength(length: int)`.

  Where `key` is a feature key whose values are used to partition the values.
  Partitions are listed from outermost to innermost.

  * If `len(partitions) == 0` (the default), then:

    * A feature from a single `tf.Example` is parsed into a 1D `tf.Tensor`.
    * A feature from a batch of `tf.Example`s is parsed into a 2D
      `tf.RaggedTensor`, where the outer dimension is the batch dimension, and
      the inner (ragged) dimension is the feature length in each example.

  * If `len(partitions) == 1`, then:

    * A feature from a single `tf.Example` is parsed into a 2D
      `tf.RaggedTensor`, where the values taken from the `value_key` are
      separated into rows using the partition key.
    * A feature from a batch of `tf.Example`s is parsed into a 3D
      `tf.RaggedTensor`, where the outer dimension is the batch dimension,
      the two inner dimensions are formed by separating the `value_key` values
      from each example into rows using that example's partition key.

  * If `len(partitions) > 1`, then:

    * A feature from a single `tf.Example` is parsed into a `tf.RaggedTensor`
      whose rank is `len(partitions)+1`, and whose ragged_rank is
      `len(partitions)`.

    * A feature from a batch of `tf.Example`s is parsed into a `tf.RaggedTensor`
      whose rank is `len(partitions)+2` and whose ragged_rank is
      `len(partitions)+1`, where the outer dimension is the batch dimension.

  There is one exception: if the final (i.e., innermost) element(s) of
  `partitions` are `UniformRowLength`s, then the values are simply reshaped (as
  a higher-dimensional `tf.Tensor`), rather than being wrapped in a
  `tf.RaggedTensor`.

  #### Examples

  >>> import google.protobuf.text_format as pbtext
  >>> example_batch = [
  ...   pbtext.Merge(r'''
  ...     features {
  ...       feature {key: "v" value {int64_list {value: [3, 1, 4, 1, 5, 9]}}}
  ...       feature {key: "s1" value {int64_list {value: [0, 2, 3, 3, 6]}}}
  ...       feature {key: "s2" value {int64_list {value: [0, 2, 3, 4]}}}
  ...     }''', tf.train.Example()).SerializeToString(),
  ...   pbtext.Merge(r'''
  ...     features {
  ...       feature {key: "v" value {int64_list {value: [2, 7, 1, 8, 2, 8, 1]}}}
  ...       feature {key: "s1" value {int64_list {value: [0, 3, 4, 5, 7]}}}
  ...       feature {key: "s2" value {int64_list {value: [0, 1, 1, 4]}}}
  ...     }''', tf.train.Example()).SerializeToString()]

  >>> features = {
  ...     # Zero partitions: returns 1D tf.Tensor for each Example.
  ...     'f1': tf.io.RaggedFeature(value_key="v", dtype=tf.int64),
  ...     # One partition: returns 2D tf.RaggedTensor for each Example.
  ...     'f2': tf.io.RaggedFeature(value_key="v", dtype=tf.int64, partitions=[
  ...         tf.io.RaggedFeature.RowSplits("s1")]),
  ...     # Two partitions: returns 3D tf.RaggedTensor for each Example.
  ...     'f3': tf.io.RaggedFeature(value_key="v", dtype=tf.int64, partitions=[
  ...         tf.io.RaggedFeature.RowSplits("s2"),
  ...         tf.io.RaggedFeature.RowSplits("s1")])
  ... }

  >>> feature_dict = tf.io.parse_single_example(example_batch[0], features)
  >>> for (name, val) in sorted(feature_dict.items()):
  ...   print('%s: %s' % (name, val))
  f1: tf.Tensor([3 1 4 1 5 9], shape=(6,), dtype=int64)
  f2: <tf.RaggedTensor [[3, 1], [4], [], [1, 5, 9]]>
  f3: <tf.RaggedTensor [[[3, 1], [4]], [[]], [[1, 5, 9]]]>

  >>> feature_dict = tf.io.parse_example(example_batch, features)
  >>> for (name, val) in sorted(feature_dict.items()):
  ...   print('%s: %s' % (name, val))
  f1: <tf.RaggedTensor [[3, 1, 4, 1, 5, 9],
                        [2, 7, 1, 8, 2, 8, 1]]>
  f2: <tf.RaggedTensor [[[3, 1], [4], [], [1, 5, 9]],
                        [[2, 7, 1], [8], [2], [8, 1]]]>
  f3: <tf.RaggedTensor [[[[3, 1], [4]], [[]], [[1, 5, 9]]],
                        [[[2, 7, 1]], [], [[8], [2], [8, 1]]]]>

  Fields:
    dtype: Data type of the `RaggedTensor`.  Must be one of:
      `tf.dtypes.int64`, `tf.dtypes.float32`, `tf.dtypes.string`.
    value_key: (Optional.) Key for a `Feature` in the input `Example`, whose
      parsed `Tensor` will be the resulting `RaggedTensor.flat_values`.  If
      not specified, then it defaults to the key for this `RaggedFeature`.
    partitions: (Optional.) A list of objects specifying the row-partitioning
      tensors (from outermost to innermost).  Each entry in this list must be
      one of:
        * `tf.io.RaggedFeature.RowSplits(key: string)`
        * `tf.io.RaggedFeature.RowLengths(key: string)`
        * `tf.io.RaggedFeature.RowStarts(key: string)`
        * `tf.io.RaggedFeature.RowLimits(key: string)`
        * `tf.io.RaggedFeature.ValueRowIds(key: string)`
        * `tf.io.RaggedFeature.UniformRowLength(length: int)`.
      Where `key` is a key for a `Feature` in the input `Example`, whose parsed
      `Tensor` will be the resulting row-partitioning tensor.
    row_splits_dtype: (Optional.) Data type for the row-partitioning tensor(s).
      One of `int32` or `int64`.  Defaults to `int32`.
    validate: (Optional.) Boolean indicating whether or not to validate that
      the input values form a valid RaggedTensor.  Defaults to `False`.
  """

  # pylint: disable=invalid-name
  RowSplits = collections.namedtuple("RowSplits", ["key"])
  RowLengths = collections.namedtuple("RowLengths", ["key"])
  RowStarts = collections.namedtuple("RowStarts", ["key"])
  RowLimits = collections.namedtuple("RowLimits", ["key"])
  ValueRowIds = collections.namedtuple("ValueRowIds", ["key"])
  UniformRowLength = collections.namedtuple("UniformRowLength", ["length"])
  # pylint: enable=invalid-name

  # Every accepted element type for `partitions`.
  _PARTITION_TYPES = (RowSplits, RowLengths, RowStarts, RowLimits, ValueRowIds,
                      UniformRowLength)

  def __new__(cls,
              dtype,
              value_key=None,
              partitions=(),
              row_splits_dtype=dtypes.int32,
              validate=False):
    """Validates and normalizes the fields of a new `RaggedFeature`.

    Raises:
      ValueError: if `value_key` is not a non-empty string (when given), or if
        `dtype` / `row_splits_dtype` is not one of the supported dtypes.
      TypeError: if `partitions` is not a list/tuple of partition objects, or
        if `validate` is not a bool.
    """
    # NOTE: user-supplied values are wrapped in 1-tuples before `%`
    # formatting.  A bare `"%r" % value` misbehaves when `value` is itself a
    # tuple, because `%` would treat it as the argument list.
    if value_key is not None:
      if not isinstance(value_key, str):
        raise ValueError("value_key must be a string; got %r" % (value_key,))
      if not value_key:
        raise ValueError("value_key may not be empty")
    dtype = dtypes.as_dtype(dtype)
    if dtype not in (dtypes.int64, dtypes.float32, dtypes.string):
      raise ValueError("dtypes must be int64, float32, or bytes; got %r" %
                       (dtype,))
    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if row_splits_dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("row_splits_dtype must be int32 or int64; got %r" %
                       (row_splits_dtype,))
    if not isinstance(partitions, (list, tuple)):
      raise TypeError("partitions must be a list or tuple")
    for partition in partitions:
      if not isinstance(partition, cls._PARTITION_TYPES):
        raise TypeError("partitions must be a list of partition objects %s;"
                        " got: %r" % (cls._PARTITION_TYPES, partition))
    if not isinstance(validate, bool):
      raise TypeError("validate must be a bool; got %r" % (validate,))
    return super(RaggedFeature, cls).__new__(cls, dtype, value_key, partitions,
                                             row_splits_dtype, validate)
221
222
@tf_export("io.SparseFeature", v1=["io.SparseFeature", "SparseFeature"])
class SparseFeature(
    collections.namedtuple(
        "SparseFeature",
        ["index_key", "value_key", "dtype", "size", "already_sorted"])):
  """Configuration for parsing a sparse input feature from an `Example`.

  Note: prefer `VarLenFeature` (possibly in combination with a
  `SequenceExample`) for parsing out `SparseTensor`s; it is simpler than
  `SparseFeature`.

  A `SparseFeature` closely mirrors the `SparseTensor` obtained by parsing an
  `Example` with it, and contains:

  * `value_key`: The key of the `Feature` in the `Example` whose parsed
    `Tensor` becomes the resulting `SparseTensor.values`.

  * `index_key`: A list of keys, one for each dimension of the resulting
    `SparseTensor`.  For the `i`-th value, `indices[i][dim]` (its position
    along dimension `dim`) is taken from the `i`-th value of the `Feature`
    with key `index_key[dim]` in the `Example`.

  * `size`: A list of ints giving the resulting `SparseTensor.dense_shape`.

  For example, we can represent the following 2D `SparseTensor`

  ```python
  SparseTensor(indices=[[3, 1], [20, 0]],
               values=[0.5, -1.0],
               dense_shape=[100, 3])
  ```

  with an `Example` input proto

  ```python
  features {
    feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } }
    feature { key: "ix0" value { int64_list { value: [ 3, 20 ] } } }
    feature { key: "ix1" value { int64_list { value: [ 1, 0 ] } } }
  }
  ```

  and `SparseFeature` config with 2 `index_key`s

  ```python
  SparseFeature(index_key=["ix0", "ix1"],
                value_key="val",
                dtype=tf.float32,
                size=[100, 3])
  ```

  Fields:
    index_key: A single string name or a list of string names of index features.
      For each key the underlying feature's type must be `int64` and its length
      must always match that of the `value_key` feature.
      To represent `SparseTensor`s with a `dense_shape` of `rank` higher than 1
      a list of length `rank` should be used.
    value_key: Name of value feature.  The underlying feature's type must
      be `dtype` and its length must always match that of all the `index_key`s'
      features.
    dtype: Data type of the `value_key` feature.
    size: A Python int or list thereof specifying the dense shape. Should be a
      list if and only if `index_key` is a list. In that case, its length must
      equal the length of `index_key`, and for each entry `i`, all values in
      the `index_key[i]` feature must be in `[0, size[i])`.
    already_sorted: A Python boolean to specify whether the values in
      `value_key` are already sorted by their index position. If so skip
      sorting. False by default (optional).
  """

  def __new__(cls, index_key, value_key, dtype, size, already_sorted=False):
    # Overridden only to give `already_sorted` a default value.
    return super(SparseFeature, cls).__new__(
        cls, index_key, value_key, dtype, size, already_sorted)
296
297
@tf_export("io.FixedLenFeature", v1=["io.FixedLenFeature", "FixedLenFeature"])
class FixedLenFeature(collections.namedtuple(
    "FixedLenFeature", ["shape", "dtype", "default_value"])):
  """Configuration for parsing a fixed-length input feature.

  Supply a `default_value` to tolerate examples in which this feature is
  absent (i.e. to treat sparse input as dense); without one, the parse
  functions fail on any example that is missing the feature.

  Fields:
    shape: Shape of input data.
    dtype: Data type of input.
    default_value: Value used when an example is missing this feature. It
        must be compatible with `dtype` and have the specified `shape`.
  """

  def __new__(cls, shape, dtype, default_value=None):
    # Overridden only so that `default_value` is optional.
    parent = super(FixedLenFeature, cls)
    return parent.__new__(
        cls, shape=shape, dtype=dtype, default_value=default_value)
316
317
@tf_export("io.FixedLenSequenceFeature",
           v1=["io.FixedLenSequenceFeature", "FixedLenSequenceFeature"])
class FixedLenSequenceFeature(collections.namedtuple(
    "FixedLenSequenceFeature",
    ["shape", "dtype", "allow_missing", "default_value"])):
  """Configuration for parsing a variable-length input feature into a `Tensor`.

  Parsing a single `SequenceExample` or `Example` yields a `Tensor` of the
  specified `dtype` with static shape `[None] + shape`.  Parsing a batch of
  `batch_size` `Example`s yields a `Tensor` with static shape
  `[batch_size, None] + shape`.  Entries that come from different `Example`s
  in the batch are padded with `default_value` up to the longest length
  present in the batch.

  Pass `allow_missing=True` to treat a sparse input as dense; otherwise the
  parse functions fail on any example that is missing this feature.

  Fields:
    shape: Shape of input data for dimension 2 and higher. First dimension is
      of variable length `None`.
    dtype: Data type of input.
    allow_missing: Whether this feature may be missing from a feature list
      item. Only available when parsing `SequenceExample`s, not `Example`s.
    default_value: Scalar value used to pad multiple `Example`s to their
      maximum length. Irrelevant when parsing a single `Example` or
      `SequenceExample`. Defaults to "" for dtype string and 0 otherwise
      (optional).
  """

  def __new__(cls, shape, dtype, allow_missing=False, default_value=None):
    # Overridden only so that `allow_missing` and `default_value` get defaults.
    base = super(FixedLenSequenceFeature, cls)
    return base.__new__(cls, shape=shape, dtype=dtype,
                        allow_missing=allow_missing,
                        default_value=default_value)
351
352
class _ParseOpParams(object):
  """Raw parameters used by `gen_parsing_ops`.

  Attributes:
    sparse_keys: A list of string keys in the examples' features. The results
      for these keys will be returned as `SparseTensor` objects.
    sparse_types: A list of `DTypes` of the same length as `sparse_keys`. Only
      `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
      (`BytesList`) are supported.
    dense_keys: A list of string keys in the examples' features. The results for
      these keys will be returned as `Tensor`s.
    dense_types: A list of DTypes of the same length as `dense_keys`. Only
      `tf.float32` (`FloatList`), `tf.int64` (`Int64List`), and `tf.string`
      (`BytesList`) are supported.
    dense_defaults: A dict mapping string keys to default values (`Tensor`s,
      values convertible to `Tensor`s, or `None` for a sequence feature that
      is allowed to be missing). The keys of the dict must match the
      dense_keys of the feature.
    dense_shapes: A list of tuples with the same length as `dense_keys`. The
      shape of the data for each dense feature referenced by `dense_keys`.
      Required for any input tensors identified by `dense_keys`.  Must be either
      fully defined, or may contain an unknown first dimension. An unknown first
      dimension means the feature is treated as having a variable number of
      blocks, and the output shape along this dimension is considered unknown at
      graph build time.  Padding is applied for minibatch elements smaller than
      the maximum number of blocks for the given feature along this dimension.
    ragged_keys: A list of string keys in the examples' features.  The
      results for these keys will be returned as `RaggedTensor` objects.
    ragged_value_types: A list of `DTypes` of the same length as `ragged_keys`,
      specifying the value type for each ragged feature.  Must be one of:
      `tf.float32`, `tf.int64`, `tf.string`.
    ragged_split_types: A list of `DTypes` of the same length as `ragged_keys`,
      specifying the row_splits type for each ragged feature.  Must be one of:
      `tf.int32`, `tf.int64`.
    dense_shapes_as_proto: dense_shapes converted to TensorShapeProto.
    dense_defaults_vec: A vector of `Tensor`s containing the default values,
      corresponding 1:1 with `dense_keys`.
    num_features: The total number of feature keys.
  """

  # Feature keys are embedded in the `name` of default-value tensors (see
  # `_make_dense_default`); every character outside this set is replaced by
  # "_" first (presumably so that the resulting op name is valid).
  _INVALID_NAME_CHARS = re.compile(r"[^A-Za-z0-9_.\-/]")

  @staticmethod
  def _as_list(values):
    """Returns a new list holding the items of `values` ([] if None)."""
    return [] if values is None else list(values)

  def __init__(self,
               sparse_keys=None,
               sparse_types=None,
               dense_keys=None,
               dense_types=None,
               dense_defaults=None,
               dense_shapes=None,
               ragged_keys=None,
               ragged_value_types=None,
               ragged_split_types=None):
    # The key lists and `dense_defaults` are copied: the `_add_*` methods
    # append to these lists and write into `dense_defaults`, and must not
    # mutate containers owned by the caller.
    sparse_keys = self._as_list(sparse_keys)
    dense_keys = self._as_list(dense_keys)
    ragged_keys = self._as_list(ragged_keys)
    if dense_shapes is None:
      # Default every dense feature to shape `[]`.
      dense_shapes = [[]] * len(dense_keys)
    self.sparse_keys = sparse_keys
    self.sparse_types = [
        dtypes.as_dtype(t) for t in self._as_list(sparse_types)
    ]
    self.dense_keys = dense_keys
    self.dense_types = [dtypes.as_dtype(t) for t in self._as_list(dense_types)]
    self.dense_shapes = [tensor_shape.as_shape(s) for s in dense_shapes]
    # Note: we use an OrderedDict for dense_defaults, to ensure consistent
    # graph construction order for _e2e_test.
    self.dense_defaults = (
        collections.OrderedDict() if dense_defaults is None else
        collections.OrderedDict(dense_defaults))
    self.ragged_keys = ragged_keys
    self.ragged_value_types = [
        dtypes.as_dtype(t) for t in self._as_list(ragged_value_types)
    ]
    self.ragged_split_types = [
        dtypes.as_dtype(t) for t in self._as_list(ragged_split_types)
    ]
    self._validate()

  @classmethod
  def from_features(cls, features, types):
    """Builds _ParseOpParams for a given set of features and allowed types.

    Args:
      features: A `dict` mapping feature keys to objects of a type in `types`.
      types: Type of features to allow, among `FixedLenFeature`,
        `VarLenFeature`, `SparseFeature`, `FixedLenSequenceFeature`, and
        `RaggedFeature`.  Any iterable of types is accepted.

    Returns:
      A `_ParseOpParams` containing the raw parameters for `gen_parsing_ops`.

    Raises:
      ValueError: if `features` contains an item not in `types`, or an invalid
          feature.
      ValueError: if sparse and dense key sets intersect.
      ValueError: if input lengths do not match up.
    """
    params = cls()
    if features:
      # Materialize once: re-evaluating `tuple(types)` per key would exhaust a
      # one-shot iterable after the first feature.
      allowed_types = tuple(types)
      # NOTE: We iterate over sorted keys to keep things deterministic.
      for key in sorted(features.keys()):
        feature = features[key]
        if not isinstance(feature, allowed_types):
          raise ValueError("Unsupported %s %s for key '%s'." %
                           (type(feature).__name__, feature, key))
        params._add_feature(key, feature)  # pylint: disable=protected-access
    params._validate()  # pylint: disable=protected-access
    return params

  @property
  def dense_shapes_as_proto(self):
    """`dense_shapes`, each converted with `TensorShape.as_proto()`."""
    return [shape.as_proto() for shape in self.dense_shapes]

  @property
  def num_features(self):
    """Total number of dense, sparse, and ragged keys."""
    return len(self.dense_keys) + len(self.sparse_keys) + len(self.ragged_keys)

  @property
  def dense_defaults_vec(self):
    """Default-value tensors, one per entry of `dense_keys` (same order)."""
    return [
        self._make_dense_default(k, s, t)
        for k, s, t in zip(self.dense_keys, self.dense_shapes, self.dense_types)
    ]

  def _make_dense_default(self, key, shape, dtype):
    """Construct the default value tensor for a specified dense feature.

    Args:
      key: The key string identifying the dense feature.
      shape: The dense feature's shape.
      dtype: The dense feature's dtype.

    Returns:
      A Tensor.  For a shape whose first dimension is unknown this is a scalar
      padding value ("" for strings, 0 otherwise, unless a default was given);
      otherwise it is the given default reshaped to `shape`, or an empty
      constant when no default was given.
    """
    default_value = self.dense_defaults.get(key)
    if (shape.ndims is not None and shape.ndims > 0 and
        shape.dims[0].value is None):
      # Variable stride dense shape, the default value should be a
      # scalar padding value.
      if default_value is None:
        default_value = ops.convert_to_tensor(
            "" if dtype == dtypes.string else 0, dtype=dtype)
      else:
        # Reshape to a scalar to ensure user gets an error if they
        # provide a tensor that's not intended to be a padding value
        # (0 or 2+ elements).
        key_name = "padding_" + self._INVALID_NAME_CHARS.sub("_", key)
        default_value = ops.convert_to_tensor(
            default_value, dtype=dtype, name=key_name)
        default_value = array_ops.reshape(default_value, [])
    else:
      if default_value is None:
        default_value = constant_op.constant([], dtype=dtype)
      elif not isinstance(default_value, ops.Tensor):
        key_name = "key_" + self._INVALID_NAME_CHARS.sub("_", key)
        default_value = ops.convert_to_tensor(
            default_value, dtype=dtype, name=key_name)
        default_value = array_ops.reshape(default_value, shape)

    return default_value

  def _add_feature(self, key, feature):
    """Adds the specified feature to this ParseOpParams."""
    if isinstance(feature, VarLenFeature):
      self._add_varlen_feature(key, feature)
    elif isinstance(feature, SparseFeature):
      self._add_sparse_feature(key, feature)
    elif isinstance(feature, FixedLenFeature):
      self._add_fixed_len_feature(key, feature)
    elif isinstance(feature, FixedLenSequenceFeature):
      self._add_fixed_len_sequence_feature(key, feature)
    elif isinstance(feature, RaggedFeature):
      self._add_ragged_feature(key, feature)
    else:
      raise ValueError("Invalid feature %s:%s." % (key, feature))

  def _add_varlen_feature(self, key, feature):
    """Adds a VarLenFeature (registered as a sparse key)."""
    if not feature.dtype:
      raise ValueError("Missing type for feature %s." % key)
    self._add_sparse_key(key, feature.dtype)

  def _add_sparse_key(self, key, dtype):
    """Adds a sparse key & dtype, checking for duplicates.

    A key that is already present is accepted only if its dtype matches.
    """
    if key in self.sparse_keys:
      original_dtype = self.sparse_types[self.sparse_keys.index(key)]
      if original_dtype != dtype:
        raise ValueError("Conflicting type %s vs %s for feature %s." %
                         (original_dtype, dtype, key))
    else:
      self.sparse_keys.append(key)
      self.sparse_types.append(dtype)

  def _add_sparse_feature(self, key, feature):
    """Adds a SparseFeature as int64 index keys plus one value key."""

    if not feature.index_key:
      raise ValueError("Missing index_key for SparseFeature %s." % (feature,))
    if not feature.value_key:
      raise ValueError("Missing value_key for SparseFeature %s." % (feature,))
    if not feature.dtype:
      raise ValueError("Missing type for feature %s." % key)
    index_keys = feature.index_key
    if isinstance(index_keys, str):
      index_keys = [index_keys]
    elif len(index_keys) > 1:
      tf_logging.warning("SparseFeature is a complicated feature config "
                         "and should only be used after careful "
                         "consideration of VarLenFeature.")
    for index_key in sorted(index_keys):
      self._add_sparse_key(index_key, dtypes.int64)
    self._add_sparse_key(feature.value_key, feature.dtype)

  def _add_fixed_len_feature(self, key, feature):
    """Adds a FixedLenFeature; its shape must be fully defined."""
    if not feature.dtype:
      raise ValueError("Missing type for feature %s." % key)
    if feature.shape is None:
      raise ValueError("Missing shape for feature %s." % key)
    feature_tensor_shape = tensor_shape.as_shape(feature.shape)
    if (feature.shape and feature_tensor_shape.ndims and
        feature_tensor_shape.dims[0].value is None):
      raise ValueError("First dimension of shape for feature %s unknown. "
                       "Consider using FixedLenSequenceFeature." % key)
    # `feature.shape` is known to be non-None here (checked above).
    if not feature_tensor_shape.is_fully_defined():
      raise ValueError("All dimensions of shape for feature %s need to be "
                       "known but received %s." % (key, str(feature.shape)))
    self.dense_keys.append(key)
    self.dense_shapes.append(feature_tensor_shape)
    self.dense_types.append(feature.dtype)
    if feature.default_value is not None:
      self.dense_defaults[key] = feature.default_value

  def _add_fixed_len_sequence_feature(self, key, feature):
    """Adds a FixedLenSequenceFeature."""
    if not feature.dtype:
      raise ValueError("Missing type for feature %s." % key)
    if feature.shape is None:
      raise ValueError("Missing shape for feature %s." % key)
    self.dense_keys.append(key)
    self.dense_shapes.append(tensor_shape.as_shape(feature.shape))
    self.dense_types.append(feature.dtype)
    if feature.allow_missing:
      # A `None` entry records that the feature may be missing; it is
      # overwritten below if an explicit default is provided.
      self.dense_defaults[key] = None
    if feature.default_value is not None:
      self.dense_defaults[key] = feature.default_value

  def _add_ragged_key(self, key, value_type, split_type):
    """Adds a ragged key & dtype, checking for duplicates.

    A key that is already present is accepted only if both its value type and
    its split type match.
    """
    if key in self.ragged_keys:
      index = self.ragged_keys.index(key)
      original_value_type = self.ragged_value_types[index]
      original_split_type = self.ragged_split_types[index]
      if original_value_type != value_type:
        raise ValueError("Conflicting type %s vs %s for feature %s." %
                         (original_value_type, value_type, key))
      if original_split_type != split_type:
        raise ValueError("Conflicting partition type %s vs %s for feature %s." %
                         (original_split_type, split_type, key))
    else:
      self.ragged_keys.append(key)
      self.ragged_value_types.append(value_type)
      self.ragged_split_types.append(split_type)

  def _add_ragged_feature(self, key, feature):
    """Adds a RaggedFeature's value key and its keyed partitions."""
    value_key = key if feature.value_key is None else feature.value_key
    self._add_ragged_key(value_key, feature.dtype, feature.row_splits_dtype)
    for partition in feature.partitions:
      # UniformRowLength carries a length, not a key, so it adds nothing.
      if not isinstance(partition, RaggedFeature.UniformRowLength):
        self._add_ragged_key(partition.key, dtypes.int64,
                             feature.row_splits_dtype)

  def _validate(self):
    """Validates the features in this ParseOpParams.

    Raises:
      ValueError: if any per-key list has a different length than its key
        list, or if the dense, sparse, and ragged key sets overlap.
    """
    if len(self.dense_shapes) != len(self.dense_keys):
      raise ValueError(
          "len(self.dense_shapes) != len(self.dense_keys): %d vs %d" %
          (len(self.dense_shapes), len(self.dense_keys)))
    if len(self.dense_types) != len(self.dense_keys):
      raise ValueError(
          "len(self.dense_types) != len(self.dense_keys): %d vs %d" %
          (len(self.dense_types), len(self.dense_keys)))
    if len(self.sparse_types) != len(self.sparse_keys):
      raise ValueError(
          "len(self.sparse_types) != len(self.sparse_keys): %d vs %d" %
          (len(self.sparse_types), len(self.sparse_keys)))
    if len(self.ragged_value_types) != len(self.ragged_keys):
      raise ValueError(
          "len(self.ragged_value_types) != len(self.ragged_keys): %d vs %d" %
          (len(self.ragged_value_types), len(self.ragged_keys)))
    if len(self.ragged_split_types) != len(self.ragged_keys):
      raise ValueError(
          "len(self.ragged_split_types) != len(self.ragged_keys): %d vs %d" %
          (len(self.ragged_split_types), len(self.ragged_keys)))

    dense_key_set = set(self.dense_keys)
    sparse_key_set = set(self.sparse_keys)
    ragged_key_set = set(self.ragged_keys)
    if not dense_key_set.isdisjoint(sparse_key_set):
      raise ValueError(
          "Dense and sparse keys must not intersect; intersection: %s" %
          dense_key_set.intersection(sparse_key_set))
    if not dense_key_set.isdisjoint(ragged_key_set):
      raise ValueError(
          "Dense and ragged keys must not intersect; intersection: %s" %
          dense_key_set.intersection(ragged_key_set))
    if not ragged_key_set.isdisjoint(sparse_key_set):
      raise ValueError(
          "Ragged and sparse keys must not intersect; intersection: %s" %
          ragged_key_set.intersection(sparse_key_set))
660
661
def _construct_tensors_for_composite_features(features, tensor_dict):
  """Builds the SparseTensors and RaggedTensors described by `features`.

  Returns a new dict derived from `tensor_dict`; the argument itself is not
  modified.  Keys of `features` are visited in sorted order, and for each one:

    * `SparseFeature`: the tensor(s) stored under the feature's `index_key`
      (a single key, or a list of keys) and under its `value_key` are combined
      with `sparse_ops.sparse_merge` into one `SparseTensor`, which is stored
      under the feature's own key.
    * `RaggedFeature`: the tensor stored under the feature's `value_key` (or
      under the feature key itself when `value_key` is None) has the feature's
      partitions applied in reverse order (last partition first), producing
      one `RaggedTensor` stored under the feature's own key.
    * Any other feature type: the value already in `tensor_dict` under that
      key is kept as-is.

  Entries of `tensor_dict` whose keys are not keys of `features` (such as
  component tensors used only to assemble a composite feature) are dropped.

  Args:
    features: A `dict` mapping feature keys to feature configuration objects.
      Only `SparseFeature` and `RaggedFeature` values cause a composite tensor
      to be built; the keys of all other values are simply copied through.
    tensor_dict: A `dict` mapping keys to `Tensor`, `SparseTensor`, and
      `RaggedTensor` values.  It must contain every key referenced by the
      `SparseFeature`s and `RaggedFeature`s in `features` (value keys, index
      keys and partition keys).

  Returns:
    A `dict` mapping feature keys to `Tensor`, `SparseTensor`, and
    `RaggedTensor` values, with one entry per key of `features` that is
    present after construction.
  """
  merged = dict(tensor_dict)  # Shallow copy; the caller's dict is untouched.
  composites = {}

  for feature_key in sorted(features):
    config = features[feature_key]

    if isinstance(config, SparseFeature):
      index_key = config.index_key
      if isinstance(index_key, str):
        id_tensors = merged[index_key]
      else:
        id_tensors = [merged[name] for name in index_key]
      value_tensor = merged[config.value_key]
      composites[feature_key] = sparse_ops.sparse_merge(
          id_tensors,
          value_tensor,
          vocab_size=config.size,
          already_sorted=config.already_sorted)

    elif isinstance(config, RaggedFeature):
      value_key = config.value_key
      source_key = feature_key if value_key is None else value_key
      ragged = merged[source_key]

      if not isinstance(ragged, ragged_tensor.RaggedTensor):
        # A single tf.Example was parsed, so the values are a plain Tensor.
        for partition in reversed(config.partitions):
          ragged = _add_ragged_partition(ragged, partition, merged,
                                         config.row_splits_dtype,
                                         config.validate)
      else:
        # A batch of tf.Examples or tf.SequenceExamples, or a single
        # tf.SequenceExample, was parsed.
        outer_splits = None
        if ragged.ragged_rank > 1:
          # A batch of tf.SequenceExamples effectively has two batch
          # dimensions.  Collapse them while the partitions are added, then
          # restore the outer one from `outer_splits` afterwards.
          outer_splits = ragged.row_splits
          ragged = ragged.values
        for partition in reversed(config.partitions):
          ragged = _add_batched_ragged_partition(
              ragged, partition, merged, feature_key, config.validate,
              outer_splits)
        if outer_splits is not None:
          ragged = ragged_tensor.RaggedTensor.from_row_splits(
              ragged, outer_splits, validate=config.validate)

      composites[feature_key] = ragged

  # Merge only after every composite tensor has been built: several features
  # may share one value_key, and one of them may use that key as its own
  # feature key, so writing results back early would corrupt later lookups.
  merged.update(composites)

  # Keep only the feature keys, discarding tensors that were used solely to
  # assemble SparseFeatures or RaggedFeatures.
  return {name: value for name, value in merged.items() if name in features}
751
752
def _add_ragged_partition(values, partition, tensor_dict, row_splits_dtype,
                          validate):
  """Adds one partition (one ragged dimension) on top of `values`.

  This is the unbatched counterpart of `_add_batched_ragged_partition`.

  Args:
    values: The values tensor (`Tensor` or `RaggedTensor`) to be partitioned.
    partition: The partition configuration object.  For
      `RaggedFeature.UniformRowLength`, no tensor is looked up and
      `partition.length` is used.  For every other type, the partition tensor
      is `tensor_dict[partition.key]`.
    tensor_dict: The dictionary mapping keys to tensors.
    row_splits_dtype: The dtype that the partition tensor (or the uniform row
      length) is converted to.
    validate: Whether the RaggedTensor factory methods should validate that
      the result is a valid RaggedTensor.

  Returns:
    A new RaggedTensor formed from `values` and the partition.  When the
    partition is a `UniformRowLength` and `values` is not a RaggedTensor, the
    result is instead a dense Tensor reshaped to
    `[-1, partition.length] + shape(values)[1:]`.

  Raises:
    ValueError: If `partition` is not a supported partition type.
  """
  if isinstance(partition, RaggedFeature.UniformRowLength):
    if not isinstance(values, ragged_tensor.RaggedTensor):
      # Dense values: the uniform dimension is added by a plain reshape.
      target_shape = array_ops.concat(
          [[-1, partition.length], array_ops.shape(values)[1:]], axis=0)
      return array_ops.reshape(values, target_shape)
    row_length = ops.convert_to_tensor(
        partition.length, dtype=row_splits_dtype)
    return ragged_tensor.RaggedTensor.from_uniform_row_length(
        values, row_length, validate=validate)

  # Every remaining partition type is backed by a tensor in tensor_dict.
  partition_tensor = math_ops.cast(tensor_dict[partition.key],
                                   row_splits_dtype)
  ragged_cls = ragged_tensor.RaggedTensor
  # Checked in order; the first matching partition type wins.
  factories = (
      (RaggedFeature.RowSplits, ragged_cls.from_row_splits),
      (RaggedFeature.RowLengths, ragged_cls.from_row_lengths),
      (RaggedFeature.RowStarts, ragged_cls.from_row_starts),
      (RaggedFeature.RowLimits, ragged_cls.from_row_limits),
      (RaggedFeature.ValueRowIds, ragged_cls.from_value_rowids),
  )
  for partition_type, factory in factories:
    if isinstance(partition, partition_type):
      return factory(values, partition_tensor, validate=validate)
  raise ValueError("Unhandled partition type %r" % partition)
796
797
def _add_batched_ragged_partition(rt, partition, tensor_dict, feature_key,
                                  validate, outer_splits=None):
  """Adds a batched ragged partition tensor to a batched ragged tensor.

  Each batch item `rt[i]` is split into rows according to batch item `i` of
  the partition tensor.  For the offset-based partition types (row splits,
  row limits, row starts and value row ids), the partition values are
  relative to each batch item.  They are therefore shifted so they index
  into the flat `rt.values` before the new ragged dimension is built.

  Args:
    rt: A RaggedTensor with shape [batch_size, ...].
    partition: The partition configuration object.  Specifies the key that
      should be used to look up the partition tensor (unless partition is a
      RaggedFeature.UniformRowLength, in which case there is no partition
      tensor).  The specified tensor must have shape [batch_size, ...].
    tensor_dict: The dictionary mapping keys to tensors.  Not used for
      RaggedFeature.UniformRowLength.
    feature_key: The name of the feature being parsed (for error messages).
    validate: Whether to validate that the values form a valid RaggedTensor.
      Also controls whether the `outer_splits` alignment assertion is added.
    outer_splits: If not None, then we have two batch dimensions, and this
      is the row-splits for the collapsed batch dimension.  Every partition
      tensor must have an outer row_splits that matches this value.

  Returns:
    A new RaggedTensor where each batch item `rt[i]` has been partitioned
    using the `partition_t[i]`.

  Raises:
    ValueError: If `partition` is not one of the handled partition types.
  """
  if isinstance(partition, RaggedFeature.UniformRowLength):
    # No partition tensor: every new row holds `partition.length` items of
    # rt.values.  Dividing rt.row_splits by the length converts the batch
    # offsets from units of values into units of the new rows.
    if rt.ragged_rank > 1:
      # rt.values is itself ragged, so the uniform dimension is added to it
      # as a RaggedTensor.
      length = ops.convert_to_tensor(partition.length, rt.row_splits.dtype)
      return ragged_tensor.RaggedTensor.from_row_splits(
          ragged_tensor.RaggedTensor.from_uniform_row_length(
              rt.values, length, validate=validate),
          rt.row_splits // length,
          validate=validate)
    else:
      # rt.values is not ragged: add the uniform dimension with a reshape to
      # [-1, partition.length] + shape(rt.values)[1:].
      reshaped_vals = array_ops.reshape(rt.values, array_ops.concat(
          [[-1, partition.length], array_ops.shape(rt.values)[1:]], axis=0))
      return ragged_tensor.RaggedTensor.from_row_splits(
          reshaped_vals, rt.row_splits // partition.length, validate=validate)

  # Look up the partition tensor.  Cast it when needed so its values share
  # rt.row_splits' dtype, which the offset arithmetic below relies on.
  partition_t = tensor_dict[partition.key]
  if partition_t.values.dtype != rt.row_splits.dtype:
    partition_t = math_ops.cast(partition_t, rt.row_splits.dtype)

  # With two batch dimensions, partition_t carries an extra outer level.
  # When validating, assert that this level matches `outer_splits`.  Then
  # strip it so partition_t has one batch item per row of rt.
  checks = []
  if outer_splits is not None:
    if validate:
      checks.append(check_ops.assert_equal(
          outer_splits, partition_t.row_splits,
          message="Feature %s: values and partitions are not aligned"
          % feature_key))
    partition_t = partition_t.values

  # Ops created in this scope get a control dependency on the alignment
  # assertion (when one was added above).
  with ops.control_dependencies(checks):
    if isinstance(partition, (RaggedFeature.RowSplits,
                              RaggedFeature.RowLimits)):
      # Row splits are turned into row limits by dropping each batch item's
      # first split.
      if isinstance(partition, RaggedFeature.RowSplits):
        partition_t = partition_t[:, 1:]
      # Shift each batch item's limits by that item's start offset in
      # rt.values (rt.row_starts()[i], repeated once per row of item i).
      adjusted_limits = partition_t.values + array_ops.repeat(
          rt.row_starts(), partition_t.row_lengths())
      # with_values keeps partition_t's grouping of rows into batch items.
      return partition_t.with_values(
          ragged_tensor.RaggedTensor.from_row_limits(
              rt.values, adjusted_limits, validate=validate))
    elif isinstance(partition, RaggedFeature.RowStarts):
      # Same per-batch-item offsetting as above, applied to row starts.
      adjusted_starts = partition_t.values + array_ops.repeat(
          rt.row_starts(), partition_t.row_lengths())
      return partition_t.with_values(
          ragged_tensor.RaggedTensor.from_row_starts(
              rt.values, adjusted_starts, validate=validate))
    elif isinstance(partition, RaggedFeature.RowLengths):
      # Row lengths are not offsets, so the flat lengths are used directly.
      return partition_t.with_values(
          ragged_tensor.RaggedTensor.from_row_lengths(
              rt.values, partition_t.values, validate=validate))
    elif isinstance(partition, RaggedFeature.ValueRowIds):
      # Rows per batch item = max row id + 1, clamped at 0.  NOTE(review):
      # the clamp presumably guards batch items with no row ids, where
      # reduce_max yields a very negative value -- confirm.
      nrows = math_ops.maximum(  # number of rows in each batch item
          ragged_math_ops.reduce_max(partition_t + 1, axis=1), 0)
      # Offset each item's row ids by the number of rows in all preceding
      # items (exclusive cumsum) so they form one flat id space.  The rows
      # are then regrouped per batch item using nrows.
      adjusted_rowids = partition_t.values + array_ops.repeat(
          math_ops.cumsum(nrows, exclusive=True), partition_t.row_lengths())
      return ragged_tensor.RaggedTensor.from_row_lengths(
          ragged_tensor.RaggedTensor.from_value_rowids(
              rt.values, adjusted_rowids, validate=validate),
          nrows,
          validate=validate)

    raise ValueError("Unhandled partition type %r" % partition)
878
879
def _build_ragged_tensors(serialized_shape,
                          ragged_values,
                          ragged_row_splits,
                          ragged_inner_splits=None):
  """Builds RaggedTensors from the outputs of a parse op.

  Args:
    serialized_shape: Shape of the serialized input; only its `ndims` is
      read.  When `ndims == 0`, the outer `ragged_row_splits` are not applied.
    ragged_values: List of values tensors.
    ragged_row_splits: List of row-splits tensors, paired element-wise with
      `ragged_values`, used to add the outer dimension.
    ragged_inner_splits: Optional list of row-splits tensors, paired
      element-wise with `ragged_values`.  When given, each values tensor is
      first wrapped with its inner splits.

  Returns:
    A list of values.  When `serialized_shape.ndims == 0` and
    `ragged_inner_splits` is None, this is the `ragged_values` list itself.
  """
  def _wrap_each(values_list, splits_list):
    # Validation is skipped for every wrap in this helper (validate=False).
    return [
        ragged_tensor.RaggedTensor.from_row_splits(values, splits,
                                                   validate=False)
        for values, splits in zip(values_list, splits_list)
    ]

  if ragged_inner_splits is not None:
    ragged_values = _wrap_each(ragged_values, ragged_inner_splits)
  if serialized_shape.ndims != 0:
    ragged_values = _wrap_each(ragged_values, ragged_row_splits)
  return ragged_values
897