1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A class used to partition a sequence into contiguous subsequences ("rows").
16"""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import numpy as np
23
24from tensorflow.python.framework import composite_tensor
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.framework import tensor_spec
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.framework import type_spec
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import check_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops.ragged import segment_id_ops
37
38
39#===============================================================================
40# RowPartition
41#===============================================================================
42# TODO(edloper): Consider removing row_starts and row_limits factory methods
43# and accessors from RowPartition.  In particular, these two encodings are
44# "second-class citizens": we never cache them, and if you do construct a
45# RowPartition from them then it may be more expensive than you might expect
46# (because we append a value to the beginning/end to transform them into
47# splits).  If we do remove them from RowPartition, then we would still keep
48# the from_row_starts and from_row_limits factory methods in RaggedTensor.
49
50
51class RowPartition(composite_tensor.CompositeTensor):
52  """Partitioning of a sequence of values into contiguous subsequences ("rows").
53
54  A `RowPartition` describes how a sequence with `nvals` items should be
55  divided into `nrows` contiguous subsequences ("rows").  For example, a
56  `RowPartition` could be used to partition the vector `[1, 2, 3, 4, 5]` into
57  subsequences `[[1, 2], [3], [], [4, 5]]`.  Note that `RowPartition` stores
58  information about how values are partitioned, but does not include the
59  partitioned values themselves.  `tf.RaggedTensor` is used to pair a `values`
60  tensor with one or more `RowPartition`s, providing a complete encoding for a
61  ragged tensor (i.e. a tensor with variable-length dimensions).
62
63  `RowPartition`s may be defined using several different schemes:
64
65    * `row_lengths`: an integer vector with shape `[nrows]`, which specifies
66      the length of each row.
67
68    * `row_splits`: an integer vector with shape `[nrows+1]`, specifying the
69      "split points" between each row.
70
71    * `row_starts`: an integer vector with shape `[nrows]`, which specifies
72      the start offset for each row.  Equivalent to `row_splits[:-1]`.
73
74    * `row_limits`: an integer vector with shape `[nrows]`, which specifies
75      the stop offset for each row.  Equivalent to `row_splits[1:]`.
76
77    * `value_rowids` is an integer vector with shape `[nvals]`, corresponding
78      one-to-one with sequence values, which specifies the row that each value
79      belongs to.  If the partition has empty trailing rows, then `nrows`
80      must also be specified.
81
82    * `uniform_row_length` is an integer scalar, specifying the length of every
83      row.  This scheme may only be used if all rows have the same length.
84
85  For example, the following `RowPartition`s all represent the partitioning of
86  8 values into 5 sublists as follows: `[[*, *, *, *], [], [*, *, *], [*], []]`.
87
88  >>> p1 = RowPartition.from_row_lengths([4, 0, 3, 1, 0])
89  >>> p2 = RowPartition.from_row_splits([0, 4, 4, 7, 8, 8])
90  >>> p3 = RowPartition.from_row_starts([0, 4, 4, 7, 8], nvals=8)
91  >>> p4 = RowPartition.from_row_limits([4, 4, 7, 8, 8])
92  >>> p5 = RowPartition.from_value_rowids([0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
93
  For more information about each scheme, see the documentation for its
  factory method.  For additional examples, see the documentation on
  `tf.RaggedTensor`.
97
98  ### Precomputed Encodings
99
100  `RowPartition` always stores at least one encoding of the partitioning, but
101  it can be configured to cache additional encodings as well.  This can
102  avoid unnecessary recomputation in eager mode.  (In graph mode, optimizations
103  such as common subexpression elimination will typically prevent these
104  unnecessary recomputations.)  To check which encodings are precomputed, use
105  `RowPartition.has_precomputed_<encoding>`.  To cache an additional
106  encoding, use `RowPartition.with_precomputed_<encoding>`.
107  """
108
109  #=============================================================================
110  # Constructor (private)
111  #=============================================================================
112  def __init__(self,
113               row_splits,
114               row_lengths=None,
115               value_rowids=None,
116               nrows=None,
117               uniform_row_length=None,
118               internal=False):
119    """Creates a `RowPartition` from the specified encoding tensor(s).
120
121    This constructor is private -- please use one of the following ops to
122    build `RowPartition`s:
123
124      * `RowPartition.from_row_lengths`
125      * `RowPartition.from_value_rowids`
126      * `RowPartition.from_row_splits`
127      * `RowPartition.from_row_starts`
128      * `RowPartition.from_row_limits`
129
130    Args:
131      row_splits: A 1-D integer tensor with shape `[nrows+1]`.
132      row_lengths: A 1-D integer tensor with shape `[nrows]`
133      value_rowids: A 1-D integer tensor with shape `[nvals]`.
134      nrows: A 1-D integer scalar tensor.
135      uniform_row_length: A scalar tensor.
136      internal: Private key value, required to ensure that this private
137        constructor is *only* called from the factory methods.
138
139    Raises:
140      TypeError: If a row partitioning tensor has an inappropriate dtype.
141      TypeError: If exactly one row partitioning argument was not specified.
142      ValueError: If a row partitioning tensor has an inappropriate shape.
143      ValueError: If multiple partitioning arguments are specified.
144      ValueError: If nrows is specified but value_rowids is not None.
145    """
146    if internal is not _row_partition_factory_key:
147      raise ValueError("RaggedTensor constructor is private; please use one "
148                       "of the factory methods instead (e.g., "
149                       "RaggedTensor.from_row_lengths())")
150
151    # Validate the arguments.
152    if not isinstance(row_splits, ops.Tensor):
153      raise TypeError("Row-partitioning argument must be a Tensor, got %r" %
154                      row_splits)
155    if row_splits.dtype not in (dtypes.int32, dtypes.int64):
156      raise ValueError("Row-partitioning argument must be int32 or int64")
157
158    # Validate shapes & dtypes.
159    row_splits.shape.assert_has_rank(1)
160    row_splits.set_shape([None])
161    self._row_splits = row_splits
162
163    # Store any cached tensors.  These are used to avoid unnecessary
164    # round-trip conversions when a RaggedTensor is constructed from
165    # lengths or rowids, and we later want those lengths/rowids back.
166    for tensor in [row_lengths, value_rowids, nrows]:
167      if tensor is not None:
168        if not isinstance(tensor, ops.Tensor):
169          raise TypeError("Cached value must be a Tensor or None.")
170        elif tensor.dtype not in (dtypes.int32, dtypes.int64):
171          raise TypeError("Cached value must be int32 or int64.")
172    self._row_lengths = row_lengths
173    self._value_rowids = value_rowids
174    self._nrows = nrows
175
176    if uniform_row_length is not None:
177      if not isinstance(uniform_row_length, ops.Tensor):
178        raise TypeError("uniform_row_length must be a Tensor or None.")
179      elif uniform_row_length.dtype not in (dtypes.int32, dtypes.int64):
180        raise TypeError("uniform_row_length must be int32 or int64.")
181    self._uniform_row_length = uniform_row_length
182
183  #=============================================================================
184  # Factory Methods
185  #=============================================================================
186
  @classmethod
  def from_value_rowids(cls,
                        value_rowids,
                        nrows=None,
                        validate=True,
                        preferred_dtype=None):
    """Creates a `RowPartition` with rows partitioned by `value_rowids`.

    This `RowPartition` divides a sequence `values` into rows by specifying
    which row each value should be added to:

    ```python
    partitioned_rows = [[] for _ in range(nrows)]
    for (value, rowid) in zip(values, value_rowids):
      partitioned_rows[rowid].append(value)
    ```

    Args:
      value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
        one-to-one with `values`, and specifies each value's row index.  Must be
        nonnegative, and must be sorted in ascending order.
      nrows: An integer scalar specifying the number of rows.  This should be
        specified if the `RowPartition` may contain empty trailing rows. Must
        be greater than `value_rowids[-1]` (or greater than or equal to zero if
        `value_rowids` is empty). Defaults to `value_rowids[-1] + 1` (or zero
        if `value_rowids` is empty).
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      preferred_dtype: The dtype to encode value_rowids if it doesn't already
        have one. The default is tf.int64.

    Returns:
      A `RowPartition`.  The computed `row_lengths`, the given `value_rowids`
      and `nrows` are cached on the returned object.

    Raises:
      TypeError: If `validate` is not a `bool`.
      ValueError: If `nrows` is statically known and is negative or
        incompatible with a statically known `value_rowids`.

    #### Example:

    >>> print(RowPartition.from_value_rowids(
    ...     value_rowids=[0, 0, 0, 0, 2, 2, 2, 3],
    ...     nrows=4))
    tf.RowPartition(row_splits=tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64))
    """
    # Local import bincount_ops to avoid import-cycle since bincount_ops
    # imports ragged_tensor.
    from tensorflow.python.ops import bincount_ops  # pylint: disable=g-import-not-at-top
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromValueRowIds",
                        [value_rowids, nrows]):
      value_rowids = cls._convert_row_partition(value_rowids, "value_rowids",
                                                preferred_dtype)
      if nrows is None:
        # Infer nrows from the last row id.
        const_rowids = tensor_util.constant_value(value_rowids)
        if const_rowids is None:
          # Value not known statically: compute `last_rowid + 1` dynamically.
          # The appended `-1` becomes element 0 when value_rowids is empty, so
          # that an empty input yields nrows == 0.
          nrows = array_ops.concat([value_rowids[-1:], [-1]], axis=0)[0] + 1
          const_nrows = None
        else:
          const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0
          nrows = ops.convert_to_tensor(
              const_nrows, value_rowids.dtype, name="nrows")
      else:
        nrows = ops.convert_to_tensor(nrows, value_rowids.dtype, "nrows")
        const_nrows = tensor_util.constant_value(nrows)
        # When values are statically known, report problems eagerly with a
        # Python exception (independent of `validate`).
        if const_nrows is not None:
          if const_nrows < 0:
            raise ValueError("Expected nrows >= 0; got %d" % const_nrows)
          const_rowids = tensor_util.constant_value(value_rowids)
          if const_rowids is not None and const_rowids.size > 0:
            if not const_nrows >= const_rowids[-1] + 1:
              raise ValueError(
                  "Expected nrows >= value_rowids[-1] + 1; got nrows=%d, "
                  "value_rowids[-1]=%d" % (const_nrows, const_rowids[-1]))

      # Static rank checks.
      value_rowids.shape.assert_has_rank(1)
      nrows.shape.assert_has_rank(0)

      if validate:
        msg = ("Arguments to from_value_rowids do not form a valid "
               "RowPartition")
        # Slices (`[:1]`, `[-1:]`) are used instead of scalar indexing so that
        # the checks are also well-defined when value_rowids is empty.
        checks = [
            check_ops.assert_rank(value_rowids, 1, message=msg),
            check_ops.assert_rank(nrows, 0, message=msg),
            check_ops.assert_non_negative(value_rowids[:1], message=msg),
            _assert_monotonic_increasing(value_rowids, message=msg),
            check_ops.assert_less(value_rowids[-1:], nrows, message=msg),
        ]
        value_rowids = control_flow_ops.with_dependencies(checks, value_rowids)

      # Convert value_rowids & nrows to row_splits.
      # Note: we don't use segment_ids_to_row_splits() here because we want
      # to save the intermediate value `row_lengths`, so we can cache it.
      # TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the
      # cast.
      value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32)
      nrows_int32 = math_ops.cast(nrows, dtypes.int32)
      # minlength == maxlength == nrows pins the output to exactly one count
      # per row.
      row_lengths = bincount_ops.bincount(
          value_rowids_int32,
          minlength=nrows_int32,
          maxlength=nrows_int32,
          dtype=value_rowids.dtype)
      row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)
      # Attach static shape information when nrows is known statically.
      if const_nrows is not None:
        row_lengths.set_shape([const_nrows])
        row_splits.set_shape([const_nrows + 1])

      return cls(
          row_splits=row_splits,
          row_lengths=row_lengths,
          value_rowids=value_rowids,
          nrows=nrows,
          internal=_row_partition_factory_key)
300
301  @classmethod
302  def from_row_splits(cls, row_splits, validate=True, preferred_dtype=None):
303    """Creates a `RowPartition` with rows partitioned by `row_splits`.
304
305    This `RowPartition` divides a sequence `values` into rows by indicating
306    where each row begins and ends:
307
308    ```python
309    partitioned_rows = []
310    for i in range(len(row_splits) - 1):
311      row_start = row_splits[i]
312      row_end = row_splits[i + 1]
313      partitioned_rows.append(values[row_start:row_end])
314    ```
315
316    Args:
317      row_splits: A 1-D integer tensor with shape `[nrows+1]`.  Must not be
318        empty, and must be sorted in ascending order.  `row_splits[0]` must be
319        zero.
320      validate: If true, then use assertions to check that the arguments form a
321        valid `RowPartition`.
322      preferred_dtype: If row_splits has an unspecified type, use this one. If
323        preferred_dtype is None, defaults to dtypes.int64.
324
325    Returns:
326      A `RowPartition`.
327
328    Raises:
329      ValueError: If `row_splits` is an empty list.
330    """
331    if not isinstance(validate, bool):
332      raise TypeError("validate must have type bool")
333    if isinstance(row_splits, (list, tuple)) and not row_splits:
334      raise ValueError("row_splits tensor may not be empty.")
335    if isinstance(row_splits, tensor_spec.TensorSpec):
336      return cls(row_splits=row_splits, internal=_row_partition_factory_key)
337
338    with ops.name_scope(None, "RowPartitionFromRowSplits", [row_splits]):
339      row_splits = cls._convert_row_partition(row_splits, "row_splits",
340                                              preferred_dtype)
341      row_splits.shape.assert_has_rank(1)
342
343      if validate:
344        msg = "Arguments to from_row_splits do not form a valid RaggedTensor:"
345        checks = [
346            check_ops.assert_rank(row_splits, 1, message=(msg + "rank")),
347            _assert_zero(row_splits[0], message=(msg + "zero")),
348            _assert_monotonic_increasing(
349                row_splits, message=(msg + "monotonic")),
350        ]
351        row_splits = control_flow_ops.with_dependencies(checks, row_splits)
352
353      return cls(row_splits=row_splits, internal=_row_partition_factory_key)
354
355  @classmethod
356  def from_row_lengths(cls, row_lengths, validate=True, preferred_dtype=None):
357    """Creates a `RowPartition` with rows partitioned by `row_lengths`.
358
359    This `RowPartition` divides a sequence `values` into rows by indicating
360    the length of each row:
361
362    ```python
363    partitioned_rows = [[values.pop(0) for _ in range(length)]
364                        for length in row_lengths]
365    ```
366
367    Args:
368      row_lengths: A 1-D integer tensor with shape `[nrows]`.  Must be
369        nonnegative.
370      validate: If true, then use assertions to check that the arguments form a
371        valid `RowPartition`.
372      preferred_dtype: If row_lengths has an unspecified type, use this one. If
373        preferred_dtype is None, defaults to dtypes.int64.
374
375    Returns:
376      A `RowPartition`.
377    """
378    if not isinstance(validate, bool):
379      raise TypeError("validate must have type bool")
380    with ops.name_scope(None, "RowPartitionFromRowLengths", [row_lengths]):
381      row_lengths = cls._convert_row_partition(row_lengths, "row_lengths",
382                                               preferred_dtype)
383      row_lengths.shape.assert_has_rank(1)
384
385      if validate:
386        msg = "Arguments to from_row_lengths do not form a valid RowPartition"
387        checks = [
388            check_ops.assert_rank(row_lengths, 1, message=msg),
389            check_ops.assert_non_negative(row_lengths, message=msg),
390        ]
391        row_lengths = control_flow_ops.with_dependencies(checks, row_lengths)
392
393      row_limits = math_ops.cumsum(row_lengths)
394      row_splits = array_ops.concat([[0], row_limits], axis=0)
395      return cls(
396          row_splits=row_splits,
397          row_lengths=row_lengths,
398          internal=_row_partition_factory_key)
399
400  @classmethod
401  def from_row_starts(cls,
402                      row_starts,
403                      nvals,
404                      validate=True,
405                      preferred_dtype=None):
406    """Creates a `RowPartition` with rows partitioned by `row_starts`.
407
408    Equivalent to: `from_row_splits(concat([row_starts, nvals], axis=0))`.
409
410    Args:
411      row_starts: A 1-D integer tensor with shape `[nrows]`.  Must be
412        nonnegative and sorted in ascending order.  If `nrows>0`, then
413        `row_starts[0]` must be zero.
414      nvals: A scalar tensor indicating the number of values.
415      validate: If true, then use assertions to check that the arguments form a
416        valid `RowPartition`.
417      preferred_dtype: If row_limits has an unspecified type, use this one. If
418        preferred_dtype is None, defaults to dtypes.int64.
419
420    Returns:
421      A `RowPartition`.
422    """
423    if not isinstance(validate, bool):
424      raise TypeError("validate must have type bool")
425    with ops.name_scope(None, "RowPartitionFromRowStarts", [row_starts]):
426      row_starts = cls._convert_row_partition(row_starts, "row_starts",
427                                              preferred_dtype)
428      row_starts.shape.assert_has_rank(1)
429      nvals = math_ops.cast(nvals, row_starts.dtype)
430      if validate:
431        msg = "Arguments to from_row_starts do not form a valid RaggedTensor"
432        checks = [
433            check_ops.assert_rank(row_starts, 1, message=msg),
434            _assert_zero(row_starts[:1], message=msg),
435            _assert_monotonic_increasing(row_starts, message=msg),
436            check_ops.assert_less_equal(row_starts[-1:], nvals, message=msg),
437        ]
438        row_starts = control_flow_ops.with_dependencies(checks, row_starts)
439
440      row_splits = array_ops.concat([row_starts, [nvals]], axis=0)
441      return cls(row_splits=row_splits, internal=_row_partition_factory_key)
442
443  @classmethod
444  def from_row_limits(cls, row_limits, validate=True, preferred_dtype=None):
445    """Creates a `RowPartition` with rows partitioned by `row_limits`.
446
447    Equivalent to: `from_row_splits(values, concat([0, row_limits], axis=0))`.
448
449    Args:
450      row_limits: A 1-D integer tensor with shape `[nrows]`.  Must be sorted in
451        ascending order.
452      validate: If true, then use assertions to check that the arguments form a
453        valid `RowPartition`.
454      preferred_dtype: If row_limits has an unspecified type, use this one. If
455        preferred_dtype is None, defaults to dtypes.int64.
456
457    Returns:
458      A `RowPartition`.
459    """
460    if not isinstance(validate, bool):
461      raise TypeError("validate must have type bool")
462    with ops.name_scope(None, "RowPartitionFromRowLimits", [row_limits]):
463      row_limits = cls._convert_row_partition(row_limits, "row_limits",
464                                              preferred_dtype)
465      row_limits.shape.assert_has_rank(1)
466
467      if validate:
468        msg = "Arguments to from_row_limits do not form a valid RaggedTensor"
469        checks = [
470            check_ops.assert_rank(row_limits, 1, message=msg),
471            check_ops.assert_non_negative(row_limits[:1], message=msg),
472            _assert_monotonic_increasing(row_limits, message=msg),
473        ]
474        row_limits = control_flow_ops.with_dependencies(checks, row_limits)
475
476      zero = array_ops.zeros([1], row_limits.dtype)
477      row_splits = array_ops.concat([zero, row_limits], axis=0)
478      return cls(row_splits=row_splits, internal=_row_partition_factory_key)
479
480  # TODO(edloper): Make nvals optional: user must specify at least one of
481  # {nvals, nrows}, but they can pick which one to specify.
  @classmethod
  def from_uniform_row_length(cls,
                              uniform_row_length,
                              nvals,
                              nrows=None,
                              validate=True,
                              preferred_dtype=None):
    """Creates a `RowPartition` with rows partitioned by `uniform_row_length`.

    This `RowPartition` divides a sequence `values` into rows that all have
    the same length:

    ```python
    partitioned_rows = [[values.pop(0) for _ in range(uniform_row_length)]
             for _ in range(nrows)]
    ```

    Args:
      uniform_row_length: A scalar integer tensor.  Must be nonnegative. The
        size of the outer axis of `values` must be evenly divisible by
        `uniform_row_length`.
      nvals: a non-negative scalar integer tensor for the number of values.
      nrows: The number of rows in the constructed RowPartition.  If not
        specified, then it defaults to `nvals/uniform_row_length` (or `0` if
        `uniform_row_length==0`).  `nrows` only needs to be specified if
        `uniform_row_length` might be zero.  `uniform_row_length*nrows` must be
        `nvals`.
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      preferred_dtype: if uniform_row_length has no dtype, use this one.

    Returns:
      A `RowPartition`, with `uniform_row_length` and `nrows` cached on it.

    Raises:
      TypeError: If `validate` is not a `bool`.
      ValueError: If `validate` is true and the statically known values are
        inconsistent (`uniform_row_length * nrows != nvals`) or
        `uniform_row_length` is statically known to be negative.
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromUniformRowLength",
                        [uniform_row_length, nrows]):
      uniform_row_length = cls._convert_row_partition(uniform_row_length,
                                                      "uniform_row_length",
                                                      preferred_dtype)
      uniform_row_length.shape.assert_has_rank(0)

      # Find nrows.
      const_row_length = tensor_util.constant_value(uniform_row_length)
      if nrows is None:
        # NOTE(review): `nvals` is used as given in the divisions below (it is
        # not converted or cast in this method) -- presumably callers pass a
        # value whose dtype is compatible with uniform_row_length; confirm.
        if const_row_length is None:
          # Avoid division by zero if uniform_row_length==0 (and nvals==0).
          rowlen_or_1 = math_ops.maximum(
              uniform_row_length,
              constant_op.constant(1, uniform_row_length.dtype))
          nrows = nvals // rowlen_or_1
        elif const_row_length == 0:
          nrows = 0
        else:
          nrows = nvals // const_row_length
      nrows = ops.convert_to_tensor(
          nrows, uniform_row_length.dtype, name="nrows")
      const_nrows = tensor_util.constant_value(nrows)
      const_nvals = tensor_util.constant_value(nvals)

      # Find row_splits.
      if const_nrows is not None and const_row_length is not None:
        # Everything is known statically: build row_splits as a constant.
        row_splits = [v * const_row_length for v in range(const_nrows + 1)]
        row_splits = constant_op.constant(row_splits, uniform_row_length.dtype)
      else:
        row_splits = math_ops.range(nrows + 1) * uniform_row_length

      if validate:
        # Runtime assertions are collected in `checks`; conditions that can be
        # decided statically raise ValueError immediately instead.
        checks = []

        if (const_nrows is None or const_row_length is None or
            const_nvals is None):
          checks.append(
              check_ops.assert_equal(
                  nrows * uniform_row_length, nvals,
                  ("uniform_row_length", uniform_row_length, "times nrows",
                   nrows, "must equal nvals", nvals)))
        else:
          if const_nrows * const_row_length != const_nvals:
            raise ValueError(
                "uniform_row_length=%d times nrows=%d must equal nvals=%d" %
                (const_row_length, const_nrows, const_nvals))

        # A runtime rank check is only needed when the rank is not known
        # statically.
        if uniform_row_length.shape.rank is None:
          checks.append(
              check_ops.assert_rank(
                  uniform_row_length,
                  0,
                  message="uniform_row_length must be a scalar."))

        # NOTE(review): this recomputes the value already obtained above;
        # uniform_row_length has not been reassigned in between.
        const_row_length = tensor_util.constant_value(uniform_row_length)
        if const_row_length is None:
          checks.append(
              check_ops.assert_greater_equal(
                  uniform_row_length,
                  constant_op.constant(0, uniform_row_length.dtype),
                  message="uniform_row_length must be >= 0."))
        else:
          if const_row_length < 0:
            raise ValueError("uniform_row_length must be >= 0.")

        # Gate row_splits on the collected runtime assertions.
        row_splits = control_flow_ops.with_dependencies(checks, row_splits)

      return cls(
          row_splits=row_splits,
          uniform_row_length=uniform_row_length,
          nrows=nrows,
          internal=_row_partition_factory_key)
591
592  @classmethod
593  def _convert_row_partition(cls, partition, name, preferred_dtype):
594    """Converts `partition` to Tensors.
595
596    Args:
597      partition: A row-partitioning tensor for the `RowPartition` being
598        constructed.  I.e., one of: row_splits, row_lengths, row_starts,
599        row_limits, value_rowids, uniform_row_length.
600      name: The name of the row-partitioning tensor.
601      preferred_dtype: If partition has no dtype, give it this one. If
602        no dtype is specified, use dtypes.int64.
603
604    Returns:
605      A tensor equivalent to partition.
606
607    Raises:
608      ValueError: if dtype is not int32 or int64.
609    """
610    if preferred_dtype is None:
611      preferred_dtype = dtypes.int64
612    if isinstance(partition, np.ndarray) and partition.dtype == np.int32:
613      partition = ops.convert_to_tensor(partition, name=name)
614    else:
615      partition = ops.convert_to_tensor(
616          partition, preferred_dtype=preferred_dtype, name=name)
617    if partition.dtype not in (dtypes.int32, dtypes.int64):
618      raise ValueError("%s must have dtype int32 or int64" % name)
619
620    return partition
621
622  def with_dependencies(self, dependencies):
623    """Returns a new RowPartition equal to self with control dependencies.
624
625    Specifically, self._row_splits is gated by the given control dependencies.
626    Used to add sanity checks to the constructors.
627
628    Args:
629      dependencies: a list of tensors to use as dependencies.
630
631    Returns:
632      A new RowPartition object.
633    """
634    new_row_splits = control_flow_ops.with_dependencies(dependencies,
635                                                        self._row_splits)
636    return RowPartition(
637        row_splits=new_row_splits,
638        row_lengths=self._row_lengths,
639        value_rowids=self._value_rowids,
640        nrows=self._nrows,
641        uniform_row_length=self._uniform_row_length,
642        internal=_row_partition_factory_key)
643
644  #=============================================================================
645  # Accessors
646  #=============================================================================
647
648  @property
649  def dtype(self):
650    """The `DType` used to encode the row partition (either int32 or int64)."""
651    return self._row_splits.dtype
652
653  def row_splits(self):
654    """Returns the row-split indices for this row partition.
655
656    `row_splits` specifies where the values for each row begin and end.
657    In particular, the values for row `i` are stored in the slice
658    `values[row_splits[i]:row_splits[i+1]]`.
659
660    Returns:
661      A 1-D integer `Tensor` with shape `[self.nrows+1]`.
662      The returned tensor is non-empty, and is sorted in ascending order.
663      `self.row_splits()[0] == 0`.
664      `self.row_splits()[-1] == self.nvals()`.
665    """
666    return self._row_splits
667
668  def value_rowids(self):
669    """Returns the row indices for this row partition.
670
671    `value_rowids` specifies the row index fo reach value.  In particular,
672    `value_rowids[i]` is the row index for `values[i]`.
673
674    Returns:
675      A 1-D integer `Tensor` with shape `[self.nvals()]`.
676      The returned tensor is nonnegative, and is sorted in ascending order.
677    """
678    if self._value_rowids is not None:
679      return self._value_rowids
680    return segment_id_ops.row_splits_to_segment_ids(self._row_splits)
681
682  def nvals(self, out_type=None):
683    """Returns the number of values partitioned by this `RowPartition`.
684
685    If the sequence partitioned by this `RowPartition` is a tensor, then
686    `nvals` is the size of that tensor's outermost dimension -- i.e.,
687    `nvals == values.shape[0]`.
688
689    Args:
690      out_type: `dtype` for the returned tensor.  Defaults to `self.dtype`.
691
692    Returns:
693      scalar integer Tensor
694    """
695    if out_type is None:
696      return self._row_splits[-1]
697    else:
698      out_type = dtypes.as_dtype(out_type)
699      return math_ops.cast(self._row_splits[-1], dtype=out_type)
700
701  def nrows(self, out_type=None):
702    """Returns the number of rows created by this `RowPartition`.
703
704    Args:
705      out_type: `dtype` for the returned tensor.  Defaults to `self.dtype`.
706
707    Returns:
708      scalar integer Tensor
709    """
710    if out_type is None:
711      out_type = self.dtype
712    else:
713      out_type = dtypes.as_dtype(out_type)
714    if self._nrows is not None:
715      return math_ops.cast(self._nrows, out_type)
716    nsplits = tensor_shape.dimension_at_index(self._row_splits.shape, 0)
717    if nsplits.value is None:
718      return array_ops.shape(self._row_splits, out_type=out_type)[0] - 1
719    else:
720      return constant_op.constant(nsplits.value - 1, dtype=out_type)
721
722  def uniform_row_length(self):
723    """Returns the length of each row in this partition, if rows are uniform.
724
725    If all rows in this `RowPartition` have the same length, then this returns
726    that length as a scalar integer `Tensor`.  Otherwise, it returns `None`.
727
728    Returns:
729      scalar Tensor with `type=self.dtype`, or `None`.
730    """
731    return self._uniform_row_length
732
733  def row_starts(self):
734    """Returns the start indices for rows in this row partition.
735
736    These indices specify where the values for each row begin.
737    `partition.row_starts()` is equal to `partition.row_splits()[:-1]`.
738
739    Returns:
740      A 1-D integer Tensor with shape `[self.nrows()]`.
741      The returned tensor is nonnegative, and is sorted in ascending order.
742      `self.row_starts()[0] == 0`.
743      `self.row_starts()[-1] <= self.nvals()`.
744    """
745    return self._row_splits[:-1]
746
747  def row_limits(self):
748    """Returns the limit indices for rows in this row partition.
749
750    These indices specify where the values for each row end.
751    `partition.row_limits()` is equal to `partition.row_splits()[:-1]`.
752
753    Returns:
754      A 1-D integer Tensor with shape `[self.nrows]`.
755      The returned tensor is nonnegative, and is sorted in ascending order.
756      `self.row_limits()[-1] == self.nvals()`.
757    """
758    return self._row_splits[1:]
759
760  def row_lengths(self):
761    """Returns the lengths of rows in this `RowPartition`.
762
763    Returns:
764      A 1-D integer Tensor with shape `[self.nrows]`.
765      The returned tensor is nonnegative.
766      `tf.reduce_sum(self.row_lengths) == self.nvals()`.
767    """
768    if self._row_lengths is not None:
769      return self._row_lengths
770    splits = self._row_splits
771    return splits[1:] - splits[:-1]
772
773  @property
774  def static_nrows(self):
775    """The number of rows in this partition, if statically known.
776
777    ```python
778    self.row_lengths().shape == [self.static_nrows]
779    self.row_starts().shape == [self.static_nrows]
780    self.row_limits().shape == [self.static_nrows]
781    self.row_splits().shape == [self.static_nrows + 1]
782    ```
783
784    Returns:
785      The number of rows in this partition as an `int` (if statically known);
786      or `None` (otherwise).
787    """
788    if self._row_splits is not None:
789      nrows = tensor_shape.dimension_at_index(self._row_splits.shape, 0) - 1
790      if nrows.value is not None:
791        return nrows
792    if self._row_lengths is not None:
793      nrows = tensor_shape.dimension_at_index(self._row_lengths.shape, 0)
794      if nrows.value is not None:
795        return nrows
796    if self._nrows is not None:
797      return tensor_shape.Dimension(tensor_util.constant_value(self._nrows))
798    return None
799
800  @property
801  def static_nvals(self):
802    """The number of values in this partition, if statically known.
803
804    ```python
805    self.value_rowids().shape == [self.static_vals]
806    ```
807
808    Returns:
809      The number of values in this partition as an `int` (if statically known);
810      or `None` (otherwise).
811    """
812    if self._value_rowids is not None:
813      nvals = tensor_shape.dimension_at_index(self._value_rowids.shape, 0)
814      if nvals.value is not None:
815        return nvals.value
816    return None
817
818  @property
819  def static_uniform_row_length(self):
820    """The number of values in each row of this partition, if statically known.
821
822    Returns:
823      The number of values in each row of this partition as an `int` (if
824      statically known); or `None` (otherwise).
825    """
826    if self._uniform_row_length is not None:
827      return tensor_util.constant_value(self._uniform_row_length)
828    return None
829
830  #=============================================================================
831  # Transformation
832  #=============================================================================
833
834  def with_row_splits_dtype(self, dtype):
835    """Returns a copy of this RowPartition with the given `row_splits` dtype.
836
837    For RaggedTensors with multiple ragged dimensions, the `row_splits` for all
838    nested `RaggedTensor` objects are cast to the given dtype.
839
840    Args:
841      dtype: The dtype for `row_splits`.  One of `tf.int32` or `tf.int64`.
842
843    Returns:
844      A copy of this RaggedTensor, with the `row_splits` cast to the given
845      type.
846    """
847    dtype = dtypes.as_dtype(dtype)
848    if dtype not in (dtypes.int32, dtypes.int64):
849      raise ValueError("dtype must be int32 or int64")
850    if self.dtype == dtype:
851      return self
852
853    return RowPartition(
854        row_splits=_cast_if_not_none(self._row_splits, dtype),
855        row_lengths=_cast_if_not_none(self._row_lengths, dtype),
856        value_rowids=_cast_if_not_none(self._value_rowids, dtype),
857        nrows=_cast_if_not_none(self._nrows, dtype),
858        uniform_row_length=_cast_if_not_none(self._uniform_row_length, dtype),
859        internal=_row_partition_factory_key)
860
861  #=============================================================================
862  # String Encoding
863  #=============================================================================
864
865  def __repr__(self):
866    return "tf.RowPartition(row_splits=%s)" % (self._row_splits)
867
868  #=============================================================================
869  # Precomputed Encodings
870  #=============================================================================
871
872  def has_precomputed_row_splits(self):
873    """Returns true if `row_splits` has already been computed.
874
875    If true, then `self.row_splits()` will return its value without calling
876    any TensorFlow ops.
877    """
878    return self._row_splits is not None
879
880  def has_precomputed_row_lengths(self):
881    """Returns true if `row_lengths` has already been computed.
882
883    If true, then `self.row_lengths()` will return its value without calling
884    any TensorFlow ops.
885    """
886    return self._row_lengths is not None
887
888  def has_precomputed_value_rowids(self):
889    """Returns true if `value_rowids` has already been computed.
890
891    If true, then `self.value_rowids()` will return its value without calling
892    any TensorFlow ops.
893    """
894    return self._value_rowids is not None
895
896  def has_precomputed_nrows(self):
897    """Returns true if `nrows` has already been computed.
898
899    If true, then `self.nrows()` will return its value without calling
900    any TensorFlow ops.
901    """
902    return self._nrows is not None
903
904  def with_precomputed_row_splits(self):
905    """Returns a copy of `self` with `row_splits` precomputed."""
906    return RowPartition(
907        row_splits=self.row_splits(),
908        row_lengths=self._row_lengths,
909        value_rowids=self._value_rowids,
910        nrows=self._nrows,
911        uniform_row_length=self._uniform_row_length,
912        internal=_row_partition_factory_key)
913
914  def with_precomputed_row_lengths(self):
915    """Returns a copy of `self` with `row_lengths` precomputed."""
916    return RowPartition(
917        row_splits=self._row_splits,
918        row_lengths=self.row_lengths(),
919        value_rowids=self._value_rowids,
920        nrows=self._nrows,
921        uniform_row_length=self._uniform_row_length,
922        internal=_row_partition_factory_key)
923
924  def with_precomputed_value_rowids(self):
925    """Returns a copy of `self` with `value_rowids` precomputed."""
926    return RowPartition(
927        row_splits=self._row_splits,
928        row_lengths=self._row_lengths,
929        value_rowids=self.value_rowids(),
930        nrows=self._nrows,
931        uniform_row_length=self._uniform_row_length,
932        internal=_row_partition_factory_key)
933
934  def with_precomputed_nrows(self):
935    """Returns a copy of `self` with `nrows` precomputed."""
936    return RowPartition(
937        row_splits=self._row_splits,
938        row_lengths=self._row_lengths,
939        value_rowids=self._value_rowids,
940        nrows=self.nrows(),
941        uniform_row_length=self._uniform_row_length,
942        internal=_row_partition_factory_key)
943
  def merge_precomputed_encodings(self, other, validate=True):
    """Returns a RowPartition that merges encodings from `self` and `other`.

    Requires that `self` and `other` describe the same partition.

    For each encoding (`nrows`, `uniform_row_length`, `row_splits`,
    `row_lengths`, `value_rowids`), the tensor stored on `self` is used when
    present, and the one stored on `other` is used otherwise.  When both
    partitions store an encoding and validation is still pending, the tensor
    from `self` is returned with an equality check against `other`'s attached.

    Args:
      other: A `RowPartition` that encodes the same partition as `self`.
      validate: If true, then add runtime checks to verify that `self` and
        `other` encode the same row partition.

    Returns:
      A `RowPartition`.  This is `self` or `other` itself when the merged
      encodings are exactly the tensors that object already stores; otherwise
      it is a newly constructed `RowPartition`.

    Raises:
      ValueError: If both partitions store the same encoding and the static
        shapes of the two tensors are incompatible.
    """
    # pylint: disable=protected-access
    if (self is other or  # Fast path if row partitions are equal.
        (self._row_splits is other._row_splits and
         self._row_lengths is other._row_lengths and
         self._value_rowids is other._value_rowids and
         self._nrows is other._nrows and
         self._uniform_row_length is other._uniform_row_length)):
      return self

    # Merge the component tensors.  We only need to validate one encoding.
    # We merge less-expensive encodings first (to avoid expensive validation).
    # Once an encoding (or combination of encodings) has been validated,
    # `validate` is switched off so the remaining merges add no more checks.
    nrows, nrows_validated = _merge_tensors(self._nrows, other._nrows, "nrows",
                                            validate)
    uniform_row_length, uniform_row_length_validated = _merge_tensors(
        self._uniform_row_length, other._uniform_row_length,
        "uniform_row_length", validate)
    # `uniform_row_length` alone is not treated as sufficient: it must be
    # validated together with `nrows`.
    if uniform_row_length_validated and nrows_validated:
      validate = False  # Validation complete.
    row_splits, row_splits_validated = _merge_tensors(self._row_splits,
                                                      other._row_splits,
                                                      "row_splits", validate)
    if row_splits_validated:
      validate = False  # Validation complete.
    row_lengths, row_lengths_validated = _merge_tensors(self._row_lengths,
                                                        other._row_lengths,
                                                        "row_lengths", validate)
    if row_lengths_validated:
      validate = False  # Validation complete.
    value_rowids, value_rowids_validated = _merge_tensors(
        self._value_rowids, other._value_rowids, "value_rowids", validate)
    # As with `uniform_row_length`, `value_rowids` only completes validation
    # in combination with `nrows`.
    if value_rowids_validated and nrows_validated:
      validate = False  # Validation complete.
    # TODO(edloper): If we make the row_splits encoding optional, then there
    # will be cases where we need to do validation at this point -- e.g. if
    # self has only row_splits and other has only value_rowids.  But for
    # now, we are guaranteed to have done validation by this point.

    # Avoid creating new RowPartition objects if we don't need to.
    if (row_splits is self._row_splits and row_lengths is self._row_lengths and
        value_rowids is self._value_rowids and nrows is self._nrows and
        uniform_row_length is self._uniform_row_length):
      return self
    if (row_splits is other._row_splits and
        row_lengths is other._row_lengths and
        value_rowids is other._value_rowids and nrows is other._nrows and
        uniform_row_length is other._uniform_row_length):
      return other

    return RowPartition(
        row_splits=row_splits,
        row_lengths=row_lengths,
        value_rowids=value_rowids,
        nrows=nrows,
        uniform_row_length=uniform_row_length,
        internal=_row_partition_factory_key)
1012
1013  #=============================================================================
1014  # Composite Tensor
1015  #=============================================================================
1016
1017  @property
1018  def _type_spec(self):
1019    return RowPartitionSpec.from_value(self)
1020
1021
1022#===============================================================================
1023# RowPartitionSpec
1024#===============================================================================
1025# TODO(edloper): Consider refactoring RowPartitionSpec to allow any combination
1026# of precomputed row-partition encodings (rather than always using row_splits).
1027
1028
class RowPartitionSpec(type_spec.TypeSpec):
  """Type specification for a `tf.RowPartition`.

  The spec records the statically known number of rows (`nrows`), number of
  values (`nvals`) and uniform row length of a partition -- each of which may
  be `None` when unknown -- together with the integer dtype of its encoding.
  """

  __slots__ = ["_nrows", "_nvals", "_uniform_row_length", "_dtype"]

  value_type = property(lambda self: RowPartition)

  def __init__(self,
               nrows=None,
               nvals=None,
               uniform_row_length=None,
               dtype=dtypes.int64):
    """Constructs a new RowPartitionSpec.

    Sizes that are left unspecified are inferred from the others when
    possible (e.g. `nvals` from `nrows * uniform_row_length`).

    Args:
      nrows: The number of rows in the RowPartition, or `None` if unspecified.
      nvals: The number of values partitioned by the RowPartition, or `None` if
        unspecified.
      uniform_row_length: The number of values in each row for this
        RowPartition, or `None` if rows are ragged or row length is unspecified.
        A rank-1 `TensorShape` wrapping that size is also accepted (this is
        the form produced by `_serialize`).
      dtype: The data type used to encode the partition.  One of `tf.int64` or
        `tf.int32`.

    Raises:
      ValueError: If `dtype` is not `tf.int32` or `tf.int64`, or if `nrows`,
        `nvals` and `uniform_row_length` are inconsistent with each other.
    """
    # Wrap dimension sizes in 1D TensorShapes so the default implementations
    # of TypeSpec methods such as `is_compatible_with` will work.
    nrows = tensor_shape.TensorShape([nrows])
    nvals = tensor_shape.TensorShape([nvals])
    if not isinstance(uniform_row_length, tensor_shape.TensorShape):
      uniform_row_length = tensor_shape.TensorShape([uniform_row_length])
    else:
      uniform_row_length = uniform_row_length.with_rank(1)

    self._nrows = nrows
    self._nvals = nvals
    self._uniform_row_length = uniform_row_length
    self._dtype = dtypes.as_dtype(dtype)
    if self._dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("dtype must be tf.int32 or tf.int64")

    # Check dimension consistency, & infer dimensions when possible.
    # From here on `nrows`, `nvals` and `ncols` are plain values (or None);
    # inferred sizes are written to the `self._*` fields only.
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    ncols = tensor_shape.dimension_value(uniform_row_length[0])
    if nrows == 0:  # no rows -> no values.
      if nvals is None:
        self._nvals = tensor_shape.TensorShape([0])
      elif nvals != 0:
        raise ValueError("nvals=%s is not compatible with nrows=%s" %
                         (nvals, nrows))
    if ncols == 0:  # there are no values in each row -> no values.
      if nvals is None:
        self._nvals = tensor_shape.TensorShape([0])
      elif nvals != 0:
        # Report the unwrapped size (as the other messages do), not the
        # `TensorShape` wrapper.
        raise ValueError("nvals=%s is not compatible with uniform_row_length"
                         "=%s" % (nvals, ncols))
    if ncols is not None and nvals is not None:
      if ncols != 0 and nvals % ncols != 0:
        raise ValueError("nvals=%s is not compatible with uniform_row_length"
                         "=%s (doesn't divide evenly)" % (nvals, ncols))
      if nrows is not None and nvals != ncols * nrows:
        raise ValueError("nvals=%s is not compatible with nrows=%s and "
                         "uniform_row_length=%s" % (nvals, nrows, ncols))
      if nrows is None and ncols != 0:
        self._nrows = tensor_shape.TensorShape([nvals // ncols])
    if ncols is not None and nrows is not None and nvals is None:
      self._nvals = tensor_shape.TensorShape([ncols * nrows])

  def is_compatible_with(self, other):
    """Returns true if `other` is compatible with this spec.

    In addition to the base-class check, the `nrows`, `nvals` and
    `uniform_row_length` obtained by merging both specs must be consistent
    with each other.

    Args:
      other: A `RowPartitionSpec`, or a `RowPartition` value (which is first
        converted to its spec).

    Returns:
      A boolean.
    """
    if not super(RowPartitionSpec, self).is_compatible_with(other):
      return False
    if isinstance(other, RowPartition):
      # NOTE(review): the base-class check is documented as accepting values
      # as well as specs -- confirm.  Convert the value so that `other.nrows`
      # etc. below are the spec's static sizes rather than the accessor
      # methods of `RowPartition`.
      other = RowPartitionSpec.from_value(other)
    nrows = self._nrows.merge_with(other.nrows)
    nvals = self._nvals.merge_with(other.nvals)
    ncols = self._uniform_row_length.merge_with(other.uniform_row_length)
    return self._dimensions_compatible(nrows, nvals, ncols)

  def _serialize(self):
    """Returns `(nrows, nvals, uniform_row_length, dtype)`.

    The three sizes are returned in their rank-1 `TensorShape` wrappers.
    """
    return (self._nrows, self._nvals, self._uniform_row_length, self._dtype)

  @classmethod
  def _deserialize(cls, serialization):
    """Rebuilds a spec from the tuple produced by `_serialize`."""
    # Remove TensorShape wrappers from serialization.
    # (`uniform_row_length` stays wrapped; the constructor accepts that form.)
    (nrows, nvals, uniform_row_length, dtype) = serialization
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    return cls(nrows, nvals, uniform_row_length, dtype)

  @property
  def nrows(self):
    """The recorded number of rows, or `None` if unknown."""
    return tensor_shape.dimension_value(self._nrows[0])

  @property
  def nvals(self):
    """The recorded number of values, or `None` if unknown."""
    return tensor_shape.dimension_value(self._nvals[0])

  @property
  def uniform_row_length(self):
    """The recorded uniform row length, or `None` if unknown."""
    return tensor_shape.dimension_value(self._uniform_row_length[0])

  @property
  def dtype(self):
    """The dtype used to encode the partition."""
    return self._dtype

  @property
  def _component_specs(self):
    """A `TensorSpec` for `row_splits`, with shape `[nrows + 1]`."""
    row_splits_shape = tensor_shape.TensorShape(
        [tensor_shape.dimension_at_index(self._nrows, 0) + 1])
    return tensor_spec.TensorSpec(row_splits_shape, self._dtype)

  def _to_components(self, value):
    """Returns the `row_splits` tensor of `value`."""
    return value.row_splits()

  def _from_components(self, tensor):
    """Builds a `RowPartition` from a `row_splits` tensor, unvalidated."""
    return RowPartition.from_row_splits(tensor, validate=False)

  @classmethod
  def from_value(cls, value):
    """Returns the spec describing the `RowPartition` `value`.

    Args:
      value: A `RowPartition`.

    Returns:
      A spec built from the static sizes and dtype of `value`.

    Raises:
      TypeError: If `value` is not a `RowPartition`.
    """
    if not isinstance(value, RowPartition):
      raise TypeError("Expected `value` to be a `RowPartition`")
    return cls(value.static_nrows, value.static_nvals,
               value.static_uniform_row_length, value.dtype)

  def __repr__(self):
    """Returns a string listing the spec's sizes and dtype."""
    return ("RowPartitionSpec(nrows=%s, nvals=%s, uniform_row_length=%s, "
            "dtype=%r)" % (self.nrows, self.nvals, self.uniform_row_length,
                           self.dtype))

  @staticmethod
  def _dimensions_compatible(nrows, nvals, uniform_row_length):
    """Returns true if the given dimensions are compatible.

    Args:
      nrows: Rank-1 shape wrapping the number of rows (possibly unknown).
      nvals: Rank-1 shape wrapping the number of values (possibly unknown).
      uniform_row_length: Rank-1 shape wrapping the uniform row length
        (possibly unknown).

    Returns:
      False if the known sizes contradict each other; True otherwise.
    """
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    ncols = tensor_shape.dimension_value(uniform_row_length[0])
    if nrows == 0 and nvals not in (0, None):
      return False  # can't have values if we have no rows.
    if ncols == 0 and nvals not in (0, None):
      return False  # can't have values if we have no values in each row.
    if ncols is not None and nvals is not None:
      if ncols != 0 and nvals % ncols != 0:
        return False  # rows aren't uniform.
      if nrows is not None and nvals != ncols * nrows:
        return False  # inconsistent number of values.
    return True
1171
1172
1173#===============================================================================
1174# Helper Functions
1175#===============================================================================
1176
1177
def _assert_monotonic_increasing(tensor, message=None):
  """Asserts that no element of `tensor` is smaller than its predecessor.

  The check is `assert_non_negative` applied to the differences between
  consecutive elements, so equal neighbours are accepted.

  Args:
    tensor: The tensor whose consecutive elements are compared.
    message: Optional message forwarded to the assertion.

  Returns:
    The result of `check_ops.assert_non_negative`.
  """
  upper = tensor[1:]
  lower = tensor[:-1]
  return check_ops.assert_non_negative(upper - lower, message=message)
1181
1182
def _assert_zero(tensor, message=None):
  """Asserts that `tensor` is equal to zero.

  Args:
    tensor: The tensor to check; the zero it is compared with is created with
      the same dtype.
    message: Optional message forwarded to the assertion.

  Returns:
    The result of `check_ops.assert_equal`.
  """
  zero = constant_op.constant(0, dtype=tensor.dtype)
  return check_ops.assert_equal(tensor, zero, message=message)
1186
1187
def _cast_if_not_none(tensor, dtype):
  """Casts `tensor` to `dtype`, passing `None` through unchanged."""
  if tensor is None:
    return None
  return math_ops.cast(tensor, dtype)
1190
1191
def _merge_tensors(t1, t2, name, validate):
  """Merge two optional Tensors with equal values into a single Tensor.

  Args:
    t1: tf.Tensor or None
    t2: tf.Tensor or None
    name: A name for the tensors (for error messages)
    validate: If true, then check that `t1` is compatible with `t2` (if both are
      non-None).

  Returns:
    A pair `(merged_value, validated)`:
      * `merged_value` is `t1` if it is not None; or `t2` otherwise.  When a
        runtime check is added, `t1` is returned with that check attached as
        a control dependency.
      * `validated` is true if we validated that t1 and t2 are equal (either
        by adding a check, or because t1 is t2).

  Raises:
    ValueError: If both tensors are given, are distinct objects, and have
      incompatible static shapes (checked regardless of `validate`).
  """
  if t1 is None:
    return t2, False
  elif t2 is None:
    return t1, False
  elif t1 is t2:
    return t1, True
  else:
    # Name the public method this helper serves, spelled correctly so the
    # message can be matched against the code.
    err_msg = ("RowPartition.merge_precomputed_encodings: partitions "
               "have incompatible %s" % name)
    if not t1.shape.is_compatible_with(t2.shape):
      raise ValueError(err_msg)
    if validate:
      checks = [check_ops.assert_equal(t1, t2, message=err_msg)]
      return control_flow_ops.with_dependencies(checks, t1), True
    else:
      return t1, False
1224
1225
# Sentinel object passed as the `internal` argument wherever this module
# constructs a `RowPartition` directly (see the
# `internal=_row_partition_factory_key` call sites).
_row_partition_factory_key = object()  # unique private object
1227