1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes for storing ragged tensors and their values."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import operator
23
24import numpy as np
25
26from tensorflow.python import tf2
27from tensorflow.python.client import session
28from tensorflow.python.framework import composite_tensor
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import sparse_tensor
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import tensor_spec
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.framework import type_spec
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import check_ops
39from tensorflow.python.ops import control_flow_ops
40from tensorflow.python.ops import gen_ragged_conversion_ops
41from tensorflow.python.ops import math_ops
42from tensorflow.python.ops.ragged import ragged_config
43from tensorflow.python.ops.ragged import ragged_tensor_value
44from tensorflow.python.ops.ragged import ragged_util
45from tensorflow.python.ops.ragged.row_partition import RowPartition
46from tensorflow.python.types import internal as internal_types
47from tensorflow.python.util import dispatch
48from tensorflow.python.util.tf_export import tf_export
49from tensorflow.tools.docs import doc_controls
50
# pylint: disable=protected-access
# Module-level alias for RowPartition's private conversion helper; the factory
# methods below use it to canonicalize row-partitioning tensors to the
# preferred row-splits dtype.
_convert_row_partition = RowPartition._convert_row_partition
# pylint: enable=protected-access
54
55#===============================================================================
56# RaggedTensor
57#===============================================================================
58
59
60@tf_export("RaggedTensor")
61class RaggedTensor(composite_tensor.CompositeTensor,
62                   internal_types.NativeObject):
63  """Represents a ragged tensor.
64
65  A `RaggedTensor` is a tensor with one or more *ragged dimensions*, which are
66  dimensions whose slices may have different lengths.  For example, the inner
67  (column) dimension of `rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged,
68  since the column slices (`rt[0, :]`, ..., `rt[4, :]`) have different lengths.
69  Dimensions whose slices all have the same length are called *uniform
70  dimensions*.  The outermost dimension of a `RaggedTensor` is always uniform,
71  since it consists of a single slice (and so there is no possibility for
72  differing slice lengths).
73
74  The total number of dimensions in a `RaggedTensor` is called its *rank*,
75  and the number of ragged dimensions in a `RaggedTensor` is called its
76  *ragged-rank*.  A `RaggedTensor`'s ragged-rank is fixed at graph creation
77  time: it can't depend on the runtime values of `Tensor`s, and can't vary
78  dynamically for different session runs.
79
80  Note that the `__init__` constructor is private. Please use one of the
81  following methods to construct a `RaggedTensor`:
82
83  * `tf.RaggedTensor.from_row_lengths`
84  * `tf.RaggedTensor.from_value_rowids`
85  * `tf.RaggedTensor.from_row_splits`
86  * `tf.RaggedTensor.from_row_starts`
87  * `tf.RaggedTensor.from_row_limits`
88  * `tf.RaggedTensor.from_nested_row_splits`
89  * `tf.RaggedTensor.from_nested_row_lengths`
90  * `tf.RaggedTensor.from_nested_value_rowids`
91
92  ### Potentially Ragged Tensors
93
94  Many ops support both `Tensor`s and `RaggedTensor`s
95  (see [tf.ragged](https://www.tensorflow.org/api_docs/python/tf/ragged) for a
96  full listing). The term "potentially ragged tensor" may be used to refer to a
97  tensor that might be either a `Tensor` or a `RaggedTensor`.  The ragged-rank
98  of a `Tensor` is zero.
99
100  ### Documenting RaggedTensor Shapes
101
102  When documenting the shape of a RaggedTensor, ragged dimensions can be
103  indicated by enclosing them in parentheses.  For example, the shape of
104  a 3-D `RaggedTensor` that stores the fixed-size word embedding for each
105  word in a sentence, for each sentence in a batch, could be written as
106  `[num_sentences, (num_words), embedding_size]`.  The parentheses around
107  `(num_words)` indicate that dimension is ragged, and that the length
108  of each element list in that dimension may vary for each item.
109
110  ### Component Tensors
111
112  Internally, a `RaggedTensor` consists of a concatenated list of values that
113  are partitioned into variable-length rows.  In particular, each `RaggedTensor`
114  consists of:
115
116    * A `values` tensor, which concatenates the variable-length rows into a
117      flattened list.  For example, the `values` tensor for
118      `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is `[3, 1, 4, 1, 5, 9, 2, 6]`.
119
120    * A `row_splits` vector, which indicates how those flattened values are
121      divided into rows.  In particular, the values for row `rt[i]` are stored
122      in the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.
123
124  Example:
125
126  >>> print(tf.RaggedTensor.from_row_splits(
127  ...       values=[3, 1, 4, 1, 5, 9, 2, 6],
128  ...       row_splits=[0, 4, 4, 7, 8, 8]))
129  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
130
131  ### Alternative Row-Partitioning Schemes
132
133  In addition to `row_splits`, ragged tensors provide support for five other
134  row-partitioning schemes:
135
136    * `row_lengths`: a vector with shape `[nrows]`, which specifies the length
137      of each row.
138
139    * `value_rowids` and `nrows`: `value_rowids` is a vector with shape
140      `[nvals]`, corresponding one-to-one with `values`, which specifies
141      each value's row index.  In particular, the row `rt[row]` consists of the
142      values `rt.values[j]` where `value_rowids[j]==row`.  `nrows` is an
143      integer scalar that specifies the number of rows in the
144      `RaggedTensor`. (`nrows` is used to indicate trailing empty rows.)
145
146    * `row_starts`: a vector with shape `[nrows]`, which specifies the start
147      offset of each row.  Equivalent to `row_splits[:-1]`.
148
149    * `row_limits`: a vector with shape `[nrows]`, which specifies the stop
150      offset of each row.  Equivalent to `row_splits[1:]`.
151
152    * `uniform_row_length`: A scalar tensor, specifying the length of every
153      row.  This row-partitioning scheme may only be used if all rows have
154      the same length.
155
156  Example: The following ragged tensors are equivalent, and all represent the
157  nested list `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]`.
158
159  >>> values = [3, 1, 4, 1, 5, 9, 2, 6]
160  >>> rt1 = RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 7, 8, 8])
161  >>> rt2 = RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 3, 1, 0])
162  >>> rt3 = RaggedTensor.from_value_rowids(
163  ...     values, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
164  >>> rt4 = RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8])
165  >>> rt5 = RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8])
166
167  ### Multiple Ragged Dimensions
168
169  `RaggedTensor`s with multiple ragged dimensions can be defined by using
170  a nested `RaggedTensor` for the `values` tensor.  Each nested `RaggedTensor`
171  adds a single ragged dimension.
172
173  >>> inner_rt = RaggedTensor.from_row_splits(  # =rt1 from above
174  ...     values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
175  >>> outer_rt = RaggedTensor.from_row_splits(
176  ...     values=inner_rt, row_splits=[0, 3, 3, 5])
177  >>> print(outer_rt.to_list())
178  [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]
179  >>> print(outer_rt.ragged_rank)
180  2
181
182  The factory function `RaggedTensor.from_nested_row_splits` may be used to
183  construct a `RaggedTensor` with multiple ragged dimensions directly, by
184  providing a list of `row_splits` tensors:
185
186  >>> RaggedTensor.from_nested_row_splits(
187  ...     flat_values=[3, 1, 4, 1, 5, 9, 2, 6],
188  ...     nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8])).to_list()
189  [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]
190
191  ### Uniform Inner Dimensions
192
193  `RaggedTensor`s with uniform inner dimensions can be defined
194  by using a multidimensional `Tensor` for `values`.
195
196  >>> rt = RaggedTensor.from_row_splits(values=tf.ones([5, 3], tf.int32),
197  ...                                   row_splits=[0, 2, 5])
198  >>> print(rt.to_list())
199  [[[1, 1, 1], [1, 1, 1]],
200   [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]
201  >>> print(rt.shape)
202  (2, None, 3)
203
204  ### Uniform Outer Dimensions
205
206  `RaggedTensor`s with uniform outer dimensions can be defined by using
207  one or more `RaggedTensor` with a `uniform_row_length` row-partitioning
208  tensor.  For example, a `RaggedTensor` with shape `[2, 2, None]` can be
209  constructed with this method from a `RaggedTensor` values with shape
210  `[4, None]`:
211
212  >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
213  >>> print(values.shape)
214  (4, None)
215  >>> rt6 = tf.RaggedTensor.from_uniform_row_length(values, 2)
216  >>> print(rt6)
217  <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]>
218  >>> print(rt6.shape)
219  (2, 2, None)
220
221  Note that `rt6` only contains one ragged dimension (the innermost
222  dimension). In contrast, if `from_row_splits` is used to construct a similar
223  `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions:
224
225  >>> rt7 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4])
226  >>> print(rt7.shape)
227  (2, None, None)
228
229  Uniform and ragged outer dimensions may be interleaved, meaning that a
230  tensor with any combination of ragged and uniform dimensions may be created.
231  For example, a RaggedTensor `t4` with shape `[3, None, 4, 8, None, 2]` could
232  be constructed as follows:
233
234  ```python
235  t0 = tf.zeros([1000, 2])                           # Shape:         [1000, 2]
236  t1 = RaggedTensor.from_row_lengths(t0, [...])      #           [160, None, 2]
237  t2 = RaggedTensor.from_uniform_row_length(t1, 8)   #         [20, 8, None, 2]
238  t3 = RaggedTensor.from_uniform_row_length(t2, 4)   #       [5, 4, 8, None, 2]
239  t4 = RaggedTensor.from_row_lengths(t3, [...])      # [3, None, 4, 8, None, 2]
240  ```
241
242  """
243
244  #=============================================================================
245  # Constructor (private)
246  #=============================================================================
247  @doc_controls.do_not_generate_docs
248  def __init__(self, values, row_partition, internal=False):
249    """Creates a `RaggedTensor` with a specified partitioning for `values`.
250
251    This constructor is private -- please use one of the following ops to
252    build `RaggedTensor`s:
253
254      * `tf.RaggedTensor.from_row_lengths`
255      * `tf.RaggedTensor.from_value_rowids`
256      * `tf.RaggedTensor.from_row_splits`
257      * `tf.RaggedTensor.from_row_starts`
258      * `tf.RaggedTensor.from_row_limits`
259      * `tf.RaggedTensor.from_nested_row_splits`
260      * `tf.RaggedTensor.from_nested_row_lengths`
261      * `tf.RaggedTensor.from_nested_value_rowids`
262
263    Args:
264      values: A potentially ragged tensor of any dtype and shape `[nvals, ...]`.
265      row_partition: A `RowPartition` object, representing the arrangement of
266        the lists at the top level.
267      internal: True if the constructor is being called by one of the factory
268        methods.  If false, an exception will be raised.
269
270    Raises:
271      ValueError: If internal = False. Note that this method is intended only
272                 for internal use.
273      TypeError: If values is not a `RaggedTensor` or `Tensor`, or
274                 row_partition is not a `RowPartition`.
275    """
276
277    if not internal:
278      raise ValueError("RaggedTensor constructor is private; please use one "
279                       "of the factory methods instead (e.g., "
280                       "RaggedTensor.from_row_lengths())")
281    _assert_is_supported_ragged_values_type(values)
282    if not isinstance(row_partition, RowPartition):
283      raise TypeError("row_partition must be a RowPartition, got %r" %
284                      row_partition)
285
286    # Validate shapes.
287    values.shape.with_rank_at_least(1)
288    if isinstance(values, RaggedTensor):
289      # pylint: disable=protected-access
290      assert row_partition.dtype == values._row_partition.dtype
291
292    self._values = values
293    self._row_partition = row_partition
294
295  #=============================================================================
296  # Factory Methods
297  #=============================================================================
298
299  @classmethod
300  def _from_row_partition(cls, values, row_partition, validate=True):
301    """Creates a `RaggedTensor` with a row partition.
302
303    This is used as a way for RaggedTensors to share row partitions.
304
305    The outer dimension of values must be equal to `partition.nvals()`.
306
307    Args:
308      values: A potentially ragged tensor.
309      row_partition: a `RowPartition`: can be shared between tensors.
310      validate: If true, then use assertions to check that the arguments form a
311        valid `RaggedTensor`.
312
313    Returns:
314      A `RaggedTensor`.  `result.rank = values.rank + 1`.
315      `result.ragged_rank = values.ragged_rank + 1`.
316
317    Raises:
318      ValueError: If partition.nvals() != _nrows(values)
319    """
320    if not isinstance(row_partition, RowPartition):
321      raise TypeError("row_partition must be a RowPartition")
322    if not isinstance(validate, bool):
323      raise TypeError("validate must have type bool")
324    values, row_partition = cls._convert_values_and_partition(
325        values, row_partition, "partition")
326    if row_partition.has_precomputed_value_rowids():
327      value_rowids_shape = row_partition.value_rowids().shape
328      values.shape[:1].assert_is_compatible_with(value_rowids_shape)
329    if validate:
330      msg = "Arguments to _from_row_partition do not form a valid RaggedTensor"
331      nvals = _nrows(values, row_partition.dtype)
332      checks = [
333          check_ops.assert_equal(
334              row_partition.nvals(out_type=row_partition.dtype),
335              nvals,
336              message=msg),
337      ]
338      if not isinstance(values, RaggedTensor):
339        checks.append(check_ops.assert_rank_at_least(values, 1))
340      row_partition = row_partition.with_dependencies(checks)
341    return cls(
342        values=values,
343        internal=True,
344        row_partition=row_partition)
345
  @classmethod
  @dispatch.add_dispatch_support
  def from_value_rowids(cls,
                        values,
                        value_rowids,
                        nrows=None,
                        name=None,
                        validate=True):
    """Creates a `RaggedTensor` with rows partitioned by `value_rowids`.

    The returned `RaggedTensor` corresponds with the python list defined by:

    ```python
    result = [[values[i] for i in range(len(values)) if value_rowids[i] == row]
              for row in range(nrows)]
    ```

    Args:
      values: A potentially ragged tensor with shape `[nvals, ...]`.
      value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
        one-to-one with `values`, and specifies each value's row index.  Must be
        nonnegative, and must be sorted in ascending order.
      nrows: An integer scalar specifying the number of rows.  This should be
        specified if the `RaggedTensor` may contain trailing empty rows.  Must
        be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty).
        Defaults to `value_rowids[-1] + 1` (or zero if `value_rowids` is empty).
      name: A name prefix for the RaggedTensor (optional).
      validate: If true, then use assertions to check that the arguments form
        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
        since they must be checked for each tensor value.

    Returns:
      A `RaggedTensor`.  `result.rank = values.rank + 1`.
      `result.ragged_rank = values.ragged_rank + 1`.

    Raises:
      ValueError: If `nrows` is incompatible with `value_rowids`.

    #### Example:

    >>> print(tf.RaggedTensor.from_value_rowids(
    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
    ...     value_rowids=[0, 0, 0, 0, 2, 2, 2, 3],
    ...     nrows=5))
    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>

    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")

    with ops.name_scope(name, "RaggedFromValueRowIds",
                        [values, value_rowids, nrows]):
      row_partition = RowPartition.from_value_rowids(
          value_rowids=value_rowids,
          nrows=nrows,
          validate=validate,
          preferred_dtype=_get_optional_partition_dtype(values))
      return cls._from_row_partition(values, row_partition, validate=validate)
404
405  @classmethod
406  @dispatch.add_dispatch_support
407  def from_row_splits(cls, values, row_splits, name=None, validate=True):
408    """Creates a `RaggedTensor` with rows partitioned by `row_splits`.
409
410    The returned `RaggedTensor` corresponds with the python list defined by:
411
412    ```python
413    result = [values[row_splits[i]:row_splits[i + 1]]
414              for i in range(len(row_splits) - 1)]
415    ```
416
417    Args:
418      values: A potentially ragged tensor with shape `[nvals, ...]`.
419      row_splits: A 1-D integer tensor with shape `[nrows+1]`.  Must not be
420        empty, and must be sorted in ascending order.  `row_splits[0]` must be
421        zero and `row_splits[-1]` must be `nvals`.
422      name: A name prefix for the RaggedTensor (optional).
423      validate: If true, then use assertions to check that the arguments form
424        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
425        since they must be checked for each tensor value.
426
427    Returns:
428      A `RaggedTensor`.  `result.rank = values.rank + 1`.
429      `result.ragged_rank = values.ragged_rank + 1`.
430
431    Raises:
432      ValueError: If `row_splits` is an empty list.
433
434    #### Example:
435
436    >>> print(tf.RaggedTensor.from_row_splits(
437    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
438    ...     row_splits=[0, 4, 4, 7, 8, 8]))
439    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
440
441    """
442    if not isinstance(validate, bool):
443      raise TypeError("validate must have type bool")
444
445    with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]):
446      row_partition = RowPartition.from_row_splits(
447          row_splits=row_splits,
448          validate=validate,
449          preferred_dtype=_get_optional_partition_dtype(values))
450      return cls._from_row_partition(values, row_partition, validate=validate)
451
452  @classmethod
453  @dispatch.add_dispatch_support
454  def from_row_lengths(cls, values, row_lengths, name=None, validate=True):
455    """Creates a `RaggedTensor` with rows partitioned by `row_lengths`.
456
457    The returned `RaggedTensor` corresponds with the python list defined by:
458
459    ```python
460    result = [[values.pop(0) for i in range(length)]
461              for length in row_lengths]
462    ```
463
464    Args:
465      values: A potentially ragged tensor with shape `[nvals, ...]`.
466      row_lengths: A 1-D integer tensor with shape `[nrows]`.  Must be
467        nonnegative.  `sum(row_lengths)` must be `nvals`.
468      name: A name prefix for the RaggedTensor (optional).
469      validate: If true, then use assertions to check that the arguments form
470        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
471        since they must be checked for each tensor value.
472
473    Returns:
474      A `RaggedTensor`.  `result.rank = values.rank + 1`.
475      `result.ragged_rank = values.ragged_rank + 1`.
476
477    #### Example:
478
479    >>> print(tf.RaggedTensor.from_row_lengths(
480    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
481    ...     row_lengths=[4, 0, 3, 1, 0]))
482    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
483
484    """
485    if not isinstance(validate, bool):
486      raise TypeError("validate must have type bool")
487
488    with ops.name_scope(name, "RaggedFromRowLengths", [values, row_lengths]):
489      row_partition = RowPartition.from_row_lengths(
490          row_lengths=row_lengths,
491          validate=validate,
492          preferred_dtype=_get_optional_partition_dtype(values))
493      return cls._from_row_partition(values, row_partition, validate=validate)
494
495  @classmethod
496  @dispatch.add_dispatch_support
497  def from_row_starts(cls, values, row_starts, name=None, validate=True):
498    """Creates a `RaggedTensor` with rows partitioned by `row_starts`.
499
500    Equivalent to: `from_row_splits(values, concat([row_starts, nvals]))`.
501
502    Args:
503      values: A potentially ragged tensor with shape `[nvals, ...]`.
504      row_starts: A 1-D integer tensor with shape `[nrows]`.  Must be
505        nonnegative and sorted in ascending order.  If `nrows>0`, then
506        `row_starts[0]` must be zero.
507      name: A name prefix for the RaggedTensor (optional).
508      validate: If true, then use assertions to check that the arguments form
509        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
510        since they must be checked for each tensor value.
511
512    Returns:
513      A `RaggedTensor`.  `result.rank = values.rank + 1`.
514      `result.ragged_rank = values.ragged_rank + 1`.
515
516    #### Example:
517
518    >>> print(tf.RaggedTensor.from_row_starts(
519    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
520    ...     row_starts=[0, 4, 4, 7, 8]))
521    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
522
523    """
524    if not isinstance(validate, bool):
525      raise TypeError("validate must have type bool")
526    with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]):
527      values = _convert_to_ragged_tensor_values(values)
528      row_partition = RowPartition.from_row_starts(
529          row_starts=row_starts,
530          nvals=_nrows(values),
531          validate=validate,
532          preferred_dtype=_get_optional_partition_dtype(values))
533      return cls._from_row_partition(values, row_partition, validate=validate)
534
535  @classmethod
536  @dispatch.add_dispatch_support
537  def from_row_limits(cls, values, row_limits, name=None, validate=True):
538    """Creates a `RaggedTensor` with rows partitioned by `row_limits`.
539
540    Equivalent to: `from_row_splits(values, concat([0, row_limits]))`.
541
542    Args:
543      values: A potentially ragged tensor with shape `[nvals, ...]`.
544      row_limits: A 1-D integer tensor with shape `[nrows]`.  Must be sorted in
545        ascending order.  If `nrows>0`, then `row_limits[-1]` must be `nvals`.
546      name: A name prefix for the RaggedTensor (optional).
547      validate: If true, then use assertions to check that the arguments form
548        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
549        since they must be checked for each tensor value.
550
551    Returns:
552      A `RaggedTensor`.  `result.rank = values.rank + 1`.
553      `result.ragged_rank = values.ragged_rank + 1`.
554
555    #### Example:
556
557    >>> print(tf.RaggedTensor.from_row_limits(
558    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
559    ...     row_limits=[4, 4, 7, 8, 8]))
560    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
561
562    """
563    if not isinstance(validate, bool):
564      raise TypeError("validate must have type bool")
565    with ops.name_scope(name, "RaggedFromRowLimits", [values, row_limits]):
566      row_partition = RowPartition.from_row_limits(
567          row_limits=row_limits,
568          validate=validate,
569          preferred_dtype=_get_optional_partition_dtype(values))
570      return cls._from_row_partition(values, row_partition, validate=validate)
571
572  @classmethod
573  @dispatch.add_dispatch_support
574  def from_uniform_row_length(cls,
575                              values,
576                              uniform_row_length,
577                              nrows=None,
578                              validate=True,
579                              name=None):
580    """Creates a `RaggedTensor` with rows partitioned by `uniform_row_length`.
581
582    This method can be used to create `RaggedTensor`s with multiple uniform
583    outer dimensions.  For example, a `RaggedTensor` with shape `[2, 2, None]`
584    can be constructed with this method from a `RaggedTensor` values with shape
585    `[4, None]`:
586
587    >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
588    >>> print(values.shape)
589    (4, None)
590    >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2)
591    >>> print(rt1)
592    <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]>
593    >>> print(rt1.shape)
594    (2, 2, None)
595
596    Note that `rt1` only contains one ragged dimension (the innermost
597    dimension). In contrast, if `from_row_splits` is used to construct a similar
598    `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions:
599
600    >>> rt2 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4])
601    >>> print(rt2.shape)
602    (2, None, None)
603
604    Args:
605      values: A potentially ragged tensor with shape `[nvals, ...]`.
606      uniform_row_length: A scalar integer tensor.  Must be nonnegative. The
607        size of the outer axis of `values` must be evenly divisible by
608        `uniform_row_length`.
609      nrows: The number of rows in the constructed RaggedTensor.  If not
610        specified, then it defaults to `nvals/uniform_row_length` (or `0` if
611        `uniform_row_length==0`).  `nrows` only needs to be specified if
612        `uniform_row_length` might be zero.  `uniform_row_length*nrows` must
613        be `nvals`.
614      validate: If true, then use assertions to check that the arguments form
615        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
616        since they must be checked for each tensor value.
617      name: A name prefix for the RaggedTensor (optional).
618
619    Returns:
620      A `RaggedTensor` that corresponds with the python list defined by:
621
622      ```python
623      result = [[values.pop(0) for i in range(uniform_row_length)]
624                for _ in range(nrows)]
625      ```
626
627      `result.rank = values.rank + 1`.
628      `result.ragged_rank = values.ragged_rank + 1`.
629    """
630    if not isinstance(validate, bool):
631      raise TypeError("validate must have type bool")
632    with ops.name_scope(name, "RaggedFromUniformRowLength",
633                        [values, uniform_row_length, nrows]):
634      values = _convert_to_ragged_tensor_values(values)
635      uniform_row_length = _convert_row_partition(
636          uniform_row_length, "UniformRowLength",
637          _get_optional_partition_dtype(values))
638      nvals = _nvals_uniform_row_length(values, uniform_row_length)
639      row_partition = RowPartition.from_uniform_row_length(
640          uniform_row_length=uniform_row_length,
641          nvals=nvals,
642          nrows=nrows,
643          validate=validate,
644          preferred_dtype=_get_optional_partition_dtype(values))
645      return cls._from_row_partition(values, row_partition, validate=validate)
646
647  @classmethod
648  @dispatch.add_dispatch_support
649  def from_nested_value_rowids(cls,
650                               flat_values,
651                               nested_value_rowids,
652                               nested_nrows=None,
653                               name=None,
654                               validate=True):
655    """Creates a `RaggedTensor` from a nested list of `value_rowids` tensors.
656
657    Equivalent to:
658
659    ```python
660    result = flat_values
661    for (rowids, nrows) in reversed(zip(nested_value_rowids, nested_nrows)):
662      result = from_value_rowids(result, rowids, nrows)
663    ```
664
665    Args:
666      flat_values: A potentially ragged tensor.
667      nested_value_rowids: A list of 1-D integer tensors.  The `i`th tensor is
668        used as the `value_rowids` for the `i`th ragged dimension.
669      nested_nrows: A list of integer scalars.  The `i`th scalar is used as the
670        `nrows` for the `i`th ragged dimension.
671      name: A name prefix for the RaggedTensor (optional).
672      validate: If true, then use assertions to check that the arguments form
673        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
674        since they must be checked for each tensor value.
675
676    Returns:
677      A `RaggedTensor` (or `flat_values` if `nested_value_rowids` is empty).
678
679    Raises:
680      ValueError: If `len(nested_values_rowids) != len(nested_nrows)`.
681    """
682    if not isinstance(validate, bool):
683      raise TypeError("validate must have type bool")
684    if isinstance(nested_value_rowids, ops.Tensor):
685      raise TypeError("nested_value_rowids must be a list of Tensors")
686    if nested_nrows is None:
687      nested_nrows = [None] * len(nested_value_rowids)
688    else:
689      if isinstance(nested_nrows, ops.Tensor):
690        raise TypeError("nested_nrows must be a list of Tensors")
691      if len(nested_nrows) != len(nested_value_rowids):
692        raise ValueError("nested_nrows must have the same length as "
693                         "nested_value_rowids")
694
695    with ops.name_scope(name, "RaggedFromNestedValueRowIds", [flat_values] +
696                        list(nested_value_rowids) + list(nested_nrows)):
697      result = flat_values
698      for value_rowids, nrows in reversed(
699          list(zip(nested_value_rowids, nested_nrows))):
700        result = cls.from_value_rowids(
701            result, value_rowids, nrows, validate=validate)
702      return result
703
704  @classmethod
705  @dispatch.add_dispatch_support
706  def from_nested_row_splits(cls,
707                             flat_values,
708                             nested_row_splits,
709                             name=None,
710                             validate=True):
711    """Creates a `RaggedTensor` from a nested list of `row_splits` tensors.
712
713    Equivalent to:
714
715    ```python
716    result = flat_values
717    for row_splits in reversed(nested_row_splits):
718      result = from_row_splits(result, row_splits)
719    ```
720
721    Args:
722      flat_values: A potentially ragged tensor.
723      nested_row_splits: A list of 1-D integer tensors.  The `i`th tensor is
724        used as the `row_splits` for the `i`th ragged dimension.
725      name: A name prefix for the RaggedTensor (optional).
726      validate: If true, then use assertions to check that the arguments form
727        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
728        since they must be checked for each tensor value.
729
730    Returns:
731      A `RaggedTensor` (or `flat_values` if `nested_row_splits` is empty).
732    """
733    if not isinstance(validate, bool):
734      raise TypeError("validate must have type bool")
735    if isinstance(nested_row_splits, ops.Tensor):
736      raise TypeError("nested_row_splits must be a list of Tensors")
737    with ops.name_scope(name, "RaggedFromNestedRowSplits",
738                        [flat_values] + list(nested_row_splits)):
739      result = flat_values
740      for splits in reversed(nested_row_splits):
741        result = cls.from_row_splits(result, splits, validate=validate)
742      return result
743
744  @classmethod
745  @dispatch.add_dispatch_support
746  def from_nested_row_lengths(cls,
747                              flat_values,
748                              nested_row_lengths,
749                              name=None,
750                              validate=True):
751    """Creates a `RaggedTensor` from a nested list of `row_lengths` tensors.
752
753    Equivalent to:
754
755    ```python
756    result = flat_values
757    for row_lengths in reversed(nested_row_lengths):
758      result = from_row_lengths(result, row_lengths)
759    ```
760
761    Args:
762      flat_values: A potentially ragged tensor.
763      nested_row_lengths: A list of 1-D integer tensors.  The `i`th tensor is
764        used as the `row_lengths` for the `i`th ragged dimension.
765      name: A name prefix for the RaggedTensor (optional).
766      validate: If true, then use assertions to check that the arguments form
767        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
768        since they must be checked for each tensor value.
769
770    Returns:
771      A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty).
772    """
773    if not isinstance(validate, bool):
774      raise TypeError("validate must have type bool")
775    if isinstance(nested_row_lengths, ops.Tensor):
776      raise TypeError("nested_row_lengths must be a list of Tensors")
777    with ops.name_scope(name, "RaggedFromNestedRowlengths",
778                        [flat_values] + list(nested_row_lengths)):
779      result = flat_values
780      for lengths in reversed(nested_row_lengths):
781        result = cls.from_row_lengths(result, lengths, validate=validate)
782      return result
783
784  @classmethod
785  def _convert_values_and_partition(cls, values, row_partition, name):
786    """Converts `values` and `partition` to Tensors.
787
788    If `values` is a `RaggedTensor`, then converts `values` and `partition`
789    to have compatible row-partitioning dtypes.  In particular, if any of the
790    row partitioning tensors are `int64`, then all of the other row
791    partitioning tensors wil be cast to `int64` (if auto_cast_partition_dtype()
792    is true) or an error will be raised (if auto_cast_partition_dtype() is
793    false).
794
795    Args:
796      values: The `values` for the `RaggedTensor` being constructed.
797      row_partition: A RowPartition object for the `RaggedTensor` being
798        constructed.
799      name: The name of the RowPartition object.
800
801    Returns:
802      A tuple (values, partition).
803    """
804    if not isinstance(row_partition, RowPartition):
805      raise ValueError("partition must be a RowPartition")
806    if isinstance(values, RaggedTensor):
807      # pylint: disable=protected-access
808      if values._row_partition.dtype != row_partition.dtype:
809        if not ragged_config.auto_cast_partition_dtype():
810          # pylint: disable=protected-access
811          raise ValueError(
812              "dtype mismatch: %s (%s) vs values.partition (%s)" %
813              (name, row_partition.dtype, values._row_partition.dtype))
814        values = values.with_row_splits_dtype(row_partition.dtype)
815    else:
816      values = _convert_to_ragged_tensor_values(values)
817
818    return (values, row_partition)
819
820  #=============================================================================
821  # Accessors
822  #=============================================================================
823
824  @property
825  def dtype(self):
826    """The `DType` of values in this tensor."""
827    return self._values.dtype
828
829  @property
830  def shape(self):
831    """The statically known shape of this ragged tensor.
832
833    Returns:
834      A `TensorShape` containing the statically known shape of this ragged
835      tensor.  Ragged dimensions have a size of `None`.
836
837    Examples:
838
839    >>> tf.ragged.constant([[0], [1, 2]]).shape
840    TensorShape([2, None])
841
842    >>> tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).shape
843    TensorShape([2, None, 2])
844
845    """
846    nrows = self._row_partition.static_nrows
847    ncols = self._row_partition.static_uniform_row_length
848    value_shape = self._values.shape[1:]
849    return tensor_shape.TensorShape([nrows, ncols]).concatenate(value_shape)
850
851  def get_shape(self):
852    """The statically known shape of this ragged tensor.
853
854    Returns:
855      A `TensorShape` containing the statically known shape of this ragged
856      tensor.  Ragged dimensions have a size of `None`.
857
858    Alias for `shape` property.
859
860    Examples:
861
862    >>> tf.ragged.constant([[0], [1, 2]]).get_shape()
863    TensorShape([2, None])
864
865    >>> tf.ragged.constant(
866    ...    [[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).get_shape()
867    TensorShape([2, None, 2])
868
869    """
870    return self.shape
871
872  @property
873  def ragged_rank(self):
874    """The number of times the RaggedTensor's flat_values is partitioned.
875
876    Examples:
877
878    >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
879    >>> values.ragged_rank
880    1
881
882    >>> rt = tf.RaggedTensor.from_uniform_row_length(values, 2)
883    >>> rt.ragged_rank
884    2
885
886    Returns:
887      A Python `int` indicating the number of times the underlying `flat_values`
888      Tensor has been partitioned to add a new dimension.
889      I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`.
890    """
891    values_is_ragged = isinstance(self._values, RaggedTensor)
892    return self._values.ragged_rank + 1 if values_is_ragged else 1
893
894  @property
895  def values(self):
896    """The concatenated rows for this ragged tensor.
897
898    `rt.values` is a potentially ragged tensor formed by flattening the two
899    outermost dimensions of `rt` into a single dimension.
900
901    `rt.values.shape = [nvals] + rt.shape[2:]` (where `nvals` is the
902    number of items in the outer two dimensions of `rt`).
903
904    `rt.ragged_rank = self.ragged_rank - 1`
905
906    Returns:
907      A potentially ragged tensor.
908
909    #### Example:
910
911    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
912    >>> print(rt.values)
913    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
914
915    """
916    return self._values
917
918  @property
919  def _nested_row_partitions(self):
920    """Returns the row partitions for this `RaggedTensor`."""
921    partitions = [self._row_partition]
922    rt_values = self.values
923    while isinstance(rt_values, RaggedTensor):
924      # pylint: disable=protected-access
925      partitions.append(rt_values._row_partition)
926      rt_values = rt_values.values
927    return tuple(partitions)
928
929  @property
930  def row_splits(self):
931    """The row-split indices for this ragged tensor's `values`.
932
933    `rt.row_splits` specifies where the values for each row begin and end in
934    `rt.values`.  In particular, the values for row `rt[i]` are stored in
935    the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.
936
937    Returns:
938      A 1-D integer `Tensor` with shape `[self.nrows+1]`.
939      The returned tensor is non-empty, and is sorted in ascending order.
940      `self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to
941      `self.values.shape[0]`.
942
943    #### Example:
944
945    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
946    >>> print(rt.row_splits)  # indices of row splits in rt.values
947    tf.Tensor([0 4 4 7 8 8], shape=(6,), dtype=int64)
948
949    """
950    return self._row_partition.row_splits()
951
952  @property
953  def uniform_row_length(self):
954    """The length of each row in this ragged tensor, or None if rows are ragged.
955
956    >>> rt1 = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
957    >>> print(rt1.uniform_row_length)  # rows are ragged.
958    None
959
960    >>> rt2 = tf.RaggedTensor.from_uniform_row_length(
961    ...     values=rt1, uniform_row_length=2)
962    >>> print(rt2)
963    <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]>
964    >>> print(rt2.uniform_row_length)  # rows are not ragged (all have size 2).
965    tf.Tensor(2, shape=(), dtype=int64)
966
967    A RaggedTensor's rows are only considered to be uniform (i.e. non-ragged)
968    if it can be determined statically (at graph construction time) that the
969    rows all have the same length.
970
971    Returns:
972      A scalar integer `Tensor`, specifying the length of every row in this
973      ragged tensor (for ragged tensors whose rows are uniform); or `None`
974      (for ragged tensors whose rows are ragged).
975    """
976    return self._row_partition.uniform_row_length()
977
978  @property
979  def flat_values(self):
980    """The innermost `values` tensor for this ragged tensor.
981
982    Concretely, if `rt.values` is a `Tensor`, then `rt.flat_values` is
983    `rt.values`; otherwise, `rt.flat_values` is `rt.values.flat_values`.
984
985    Conceptually, `flat_values` is the tensor formed by flattening the
986    outermost dimension and all of the ragged dimensions into a single
987    dimension.
988
989    `rt.flat_values.shape = [nvals] + rt.shape[rt.ragged_rank + 1:]`
990    (where `nvals` is the number of items in the flattened dimensions).
991
992    Returns:
993      A `Tensor`.
994
995    #### Example:
996
997    >>> rt = tf.ragged.constant([[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
998    >>> print(rt.flat_values)
999    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1000
1001    """
1002    rt_values = self.values
1003    while isinstance(rt_values, RaggedTensor):
1004      rt_values = rt_values.values
1005    return rt_values
1006
1007  @property
1008  def nested_row_splits(self):
1009    """A tuple containing the row_splits for all ragged dimensions.
1010
1011    `rt.nested_row_splits` is a tuple containing the `row_splits` tensors for
1012    all ragged dimensions in `rt`, ordered from outermost to innermost.  In
1013    particular, `rt.nested_row_splits = (rt.row_splits,) + value_splits` where:
1014
1015        * `value_splits = ()` if `rt.values` is a `Tensor`.
1016        * `value_splits = rt.values.nested_row_splits` otherwise.
1017
1018    Returns:
1019      A `tuple` of 1-D integer `Tensor`s.
1020
1021    #### Example:
1022
1023    >>> rt = tf.ragged.constant(
1024    ...     [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]])
1025    >>> for i, splits in enumerate(rt.nested_row_splits):
1026    ...   print('Splits for dimension %d: %s' % (i+1, splits.numpy()))
1027    Splits for dimension 1: [0 3]
1028    Splits for dimension 2: [0 3 3 5]
1029    Splits for dimension 3: [0 4 4 7 8 8]
1030
1031    """
1032    rt_nested_splits = [self.row_splits]
1033    rt_values = self.values
1034    while isinstance(rt_values, RaggedTensor):
1035      rt_nested_splits.append(rt_values.row_splits)
1036      rt_values = rt_values.values
1037    return tuple(rt_nested_splits)
1038
1039  def value_rowids(self, name=None):
1040    """Returns the row indices for the `values` in this ragged tensor.
1041
1042    `rt.value_rowids()` corresponds one-to-one with the outermost dimension of
1043    `rt.values`, and specifies the row containing each value.  In particular,
1044    the row `rt[row]` consists of the values `rt.values[j]` where
1045    `rt.value_rowids()[j] == row`.
1046
1047    Args:
1048      name: A name prefix for the returned tensor (optional).
1049
1050    Returns:
1051      A 1-D integer `Tensor` with shape `self.values.shape[:1]`.
1052      The returned tensor is nonnegative, and is sorted in ascending order.
1053
1054    #### Example:
1055
1056    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1057    >>> print(rt.values)
1058    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1059    >>> print(rt.value_rowids())  # corresponds 1:1 with rt.values
1060    tf.Tensor([0 0 0 0 2 2 2 3], shape=(8,), dtype=int64)
1061
1062    """
1063    with ops.name_scope(name, "RaggedValueRowIds", [self]):
1064      return self._row_partition.value_rowids()
1065
1066  def nested_value_rowids(self, name=None):
1067    """Returns a tuple containing the value_rowids for all ragged dimensions.
1068
1069    `rt.nested_value_rowids` is a tuple containing the `value_rowids` tensors
1070    for
1071    all ragged dimensions in `rt`, ordered from outermost to innermost.  In
1072    particular, `rt.nested_value_rowids = (rt.value_rowids(),) + value_ids`
1073    where:
1074
1075    * `value_ids = ()` if `rt.values` is a `Tensor`.
1076    * `value_ids = rt.values.nested_value_rowids` otherwise.
1077
1078    Args:
1079      name: A name prefix for the returned tensors (optional).
1080
1081    Returns:
1082      A `tuple` of 1-D integer `Tensor`s.
1083
1084    #### Example:
1085
1086    >>> rt = tf.ragged.constant(
1087    ...     [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]])
1088    >>> for i, ids in enumerate(rt.nested_value_rowids()):
1089    ...   print('row ids for dimension %d: %s' % (i+1, ids.numpy()))
1090    row ids for dimension 1: [0 0 0]
1091    row ids for dimension 2: [0 0 0 2 2]
1092    row ids for dimension 3: [0 0 0 0 2 2 2 3]
1093
1094    """
1095    with ops.name_scope(name, "RaggedNestedValueRowIds", [self]):
1096      rt_nested_ids = [self.value_rowids()]
1097      rt_values = self.values
1098      while isinstance(rt_values, RaggedTensor):
1099        rt_nested_ids.append(rt_values.value_rowids())
1100        rt_values = rt_values.values
1101      return tuple(rt_nested_ids)
1102
1103  def nrows(self, out_type=None, name=None):
1104    """Returns the number of rows in this ragged tensor.
1105
1106    I.e., the size of the outermost dimension of the tensor.
1107
1108    Args:
1109      out_type: `dtype` for the returned tensor.  Defaults to
1110        `self.row_splits.dtype`.
1111      name: A name prefix for the returned tensor (optional).
1112
1113    Returns:
1114      A scalar `Tensor` with dtype `out_type`.
1115
1116    #### Example:
1117
1118    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1119    >>> print(rt.nrows())  # rt has 5 rows.
1120    tf.Tensor(5, shape=(), dtype=int64)
1121
1122    """
1123    with ops.name_scope(name, "RaggedNRows", [self]):
1124      return self._row_partition.nrows(out_type=out_type)
1125
1126  def row_starts(self, name=None):
1127    """Returns the start indices for rows in this ragged tensor.
1128
1129    These indices specify where the values for each row begin in
1130    `self.values`.  `rt.row_starts()` is equal to `rt.row_splits[:-1]`.
1131
1132    Args:
1133      name: A name prefix for the returned tensor (optional).
1134
1135    Returns:
1136      A 1-D integer Tensor with shape `[nrows]`.
1137      The returned tensor is nonnegative, and is sorted in ascending order.
1138
1139    #### Example:
1140
1141    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1142    >>> print(rt.values)
1143    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1144    >>> print(rt.row_starts())  # indices of row starts in rt.values
1145    tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64)
1146
1147    """
1148    with ops.name_scope(name, "RaggedRowStarts", [self]):
1149      return self._row_partition.row_starts()
1150
1151  def row_limits(self, name=None):
1152    """Returns the limit indices for rows in this ragged tensor.
1153
1154    These indices specify where the values for each row end in
1155    `self.values`.  `rt.row_limits(self)` is equal to `rt.row_splits[:-1]`.
1156
1157    Args:
1158      name: A name prefix for the returned tensor (optional).
1159
1160    Returns:
1161      A 1-D integer Tensor with shape `[nrows]`.
1162      The returned tensor is nonnegative, and is sorted in ascending order.
1163
1164    #### Example:
1165
1166    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1167    >>> print(rt.values)
1168    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1169    >>> print(rt.row_limits())  # indices of row limits in rt.values
1170    tf.Tensor([4 4 7 8 8], shape=(5,), dtype=int64)
1171
1172    """
1173    with ops.name_scope(name, "RaggedRowLimits", [self]):
1174      return self._row_partition.row_limits()
1175
1176  def row_lengths(self, axis=1, name=None):
1177    """Returns the lengths of the rows in this ragged tensor.
1178
1179    `rt.row_lengths()[i]` indicates the number of values in the
1180    `i`th row of `rt`.
1181
1182    Args:
1183      axis: An integer constant indicating the axis whose row lengths should be
1184        returned.
1185      name: A name prefix for the returned tensor (optional).
1186
1187    Returns:
1188      A potentially ragged integer Tensor with shape `self.shape[:axis]`.
1189
1190    Raises:
1191      ValueError: If `axis` is out of bounds.
1192
1193    #### Example:
1194
1195    >>> rt = tf.ragged.constant(
1196    ...     [[[3, 1, 4], [1]], [], [[5, 9], [2]], [[6]], []])
1197    >>> print(rt.row_lengths())  # lengths of rows in rt
1198    tf.Tensor([2 0 2 1 0], shape=(5,), dtype=int64)
1199    >>> print(rt.row_lengths(axis=2))  # lengths of axis=2 rows.
1200    <tf.RaggedTensor [[3, 1], [], [2, 1], [1], []]>
1201
1202    """
1203    if axis == 0:
1204      return self._row_partition.nrows()
1205
1206    if axis == 1:
1207      return self._row_partition.row_lengths()
1208
1209    with ops.name_scope(name, "RaggedRowLengths", [self]):
1210      axis = array_ops.get_positive_axis(
1211          axis, self.shape.rank, ndims_name="rank(self)")
1212      if axis == 0:
1213        return self.nrows()
1214      elif axis == 1:
1215        splits = self.row_splits
1216        return splits[1:] - splits[:-1]
1217      elif isinstance(self.values, RaggedTensor):
1218        return self.with_values(self.values.row_lengths(axis - 1))
1219      else:
1220        shape = array_ops.shape(self.values, out_type=self._row_partition.dtype)
1221        return self.with_values(
1222            array_ops.ones(shape[:axis - 1], self._row_partition.dtype) *
1223            shape[axis - 1])
1224
1225  def nested_row_lengths(self, name=None):
1226    """Returns a tuple containing the row_lengths for all ragged dimensions.
1227
1228    `rt.nested_row_lengths()` is a tuple containing the `row_lengths` tensors
1229    for all ragged dimensions in `rt`, ordered from outermost to innermost.
1230
1231    Args:
1232      name: A name prefix for the returned tensors (optional).
1233
1234    Returns:
1235      A `tuple` of 1-D integer `Tensors`.  The length of the tuple is equal to
1236      `self.ragged_rank`.
1237    """
1238    with ops.name_scope(name, "RaggedNestedRowLengths", [self]):
1239      rt_nested_row_lengths = []
1240      rt = self
1241      while isinstance(rt, RaggedTensor):
1242        rt_nested_row_lengths.append(rt.row_lengths())
1243        rt = rt.values
1244      return tuple(rt_nested_row_lengths)
1245
1246  def bounding_shape(self, axis=None, name=None, out_type=None):
1247    """Returns the tight bounding box shape for this `RaggedTensor`.
1248
1249    Args:
1250      axis: An integer scalar or vector indicating which axes to return the
1251        bounding box for.  If not specified, then the full bounding box is
1252        returned.
1253      name: A name prefix for the returned tensor (optional).
1254      out_type: `dtype` for the returned tensor.  Defaults to
1255        `self.row_splits.dtype`.
1256
1257    Returns:
1258      An integer `Tensor` (`dtype=self.row_splits.dtype`).  If `axis` is not
1259      specified, then `output` is a vector with
1260      `output.shape=[self.shape.ndims]`.  If `axis` is a scalar, then the
1261      `output` is a scalar.  If `axis` is a vector, then `output` is a vector,
1262      where `output[i]` is the bounding size for dimension `axis[i]`.
1263
1264    #### Example:
1265
1266    >>> rt = tf.ragged.constant([[1, 2, 3, 4], [5], [], [6, 7, 8, 9], [10]])
1267    >>> rt.bounding_shape().numpy()
1268    array([5, 4])
1269
1270    """
1271    if out_type is None:
1272      out_type = self._row_partition.dtype
1273    else:
1274      out_type = dtypes.as_dtype(out_type)
1275    with ops.name_scope(name, "RaggedBoundingBox", [self, axis]):
1276      nested_splits = self.nested_row_splits
1277      rt_flat_values = self.flat_values
1278
1279      # Optimized special cases for when axis=0 or axis=1:
1280      if isinstance(axis, int):
1281        if axis == 0:
1282          return array_ops.shape(nested_splits[0], out_type=out_type)[0] - 1
1283        elif axis == 1:
1284          return math_ops.maximum(math_ops.reduce_max(self.row_lengths()), 0)
1285
1286      splits_shape = array_ops.shape(self.row_splits, out_type=out_type)
1287      flat_values_shape = array_ops.shape(rt_flat_values, out_type=out_type)
1288
1289      ragged_dimensions = [splits_shape[0] - 1] + [
1290          math_ops.maximum(math_ops.reduce_max(splits[1:] - splits[:-1]), 0)
1291          for splits in nested_splits
1292      ]
1293      inner_dimensions = flat_values_shape[1:]
1294
1295      if out_type != self._row_partition.dtype:
1296        ragged_dimensions = [
1297            math_ops.cast(d, out_type) for d in ragged_dimensions
1298        ]
1299      bbox = array_ops.concat(
1300          [array_ops.stack(ragged_dimensions), inner_dimensions], axis=0)
1301      return bbox if axis is None else array_ops.gather(bbox, axis)
1302
1303  #=============================================================================
1304  # Transformation
1305  #=============================================================================
1306
1307  def with_values(self, new_values):
1308    """Returns a copy of `self` with `values` replaced by `new_value`.
1309
1310    Preserves cached row-partitioning tensors such as `self.cached_nrows` and
1311    `self.cached_value_rowids` if they have values.
1312
1313    Args:
1314      new_values: Potentially ragged tensor to use as the `values` for the
1315        returned `RaggedTensor`.  Must have `rank > 0`, and must have the same
1316        number of rows as `self.values`.
1317
1318    Returns:
1319      A `RaggedTensor`.  `result.rank = 1 + new_values.rank`.
1320      `result.ragged_rank = 1 + new_values.ragged_rank`
1321    """
1322    new_values = _convert_to_ragged_tensor_values(new_values)
1323    new_values.shape.with_rank_at_least(1)
1324    self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1])
1325    if (isinstance(new_values, RaggedTensor) and
1326        self._row_partition.dtype != new_values.row_splits.dtype):
1327      if not ragged_config.auto_cast_partition_dtype():
1328        raise ValueError("self and new_values have mismatched row_splits "
1329                         "dtypes; use RaggedTensor.with_row_splits_dtype() to "
1330                         "convert them to compatible dtypes.")
1331      new_values = new_values.with_row_splits_dtype(dtypes.int64)
1332      return self.with_row_splits_dtype(dtypes.int64).with_values(new_values)
1333    return RaggedTensor(
1334        values=new_values, row_partition=self._row_partition, internal=True)
1335
1336  def with_flat_values(self, new_values):
1337    """Returns a copy of `self` with `flat_values` replaced by `new_value`.
1338
1339    Preserves cached row-partitioning tensors such as `self.cached_nrows` and
1340    `self.cached_value_rowids` if they have values.
1341
1342    Args:
1343      new_values: Potentially ragged tensor that should replace
1344        `self.flat_values`.  Must have `rank > 0`, and must have the same number
1345        of rows as `self.flat_values`.
1346
1347    Returns:
1348      A `RaggedTensor`.
1349      `result.rank = self.ragged_rank + new_values.rank`.
1350      `result.ragged_rank = self.ragged_rank + new_values.ragged_rank`.
1351    """
1352    if isinstance(self._values, RaggedTensor):
1353      return self.with_values(self.values.with_flat_values(new_values))
1354    else:
1355      new_values = _convert_to_ragged_tensor_values(new_values)
1356    return self.with_values(new_values)
1357
1358  def with_row_splits_dtype(self, dtype):
1359    """Returns a copy of this RaggedTensor with the given `row_splits` dtype.
1360
1361    For RaggedTensors with multiple ragged dimensions, the `row_splits` for all
1362    nested `RaggedTensor` objects are cast to the given dtype.
1363
1364    Args:
1365      dtype: The dtype for `row_splits`.  One of `tf.int32` or `tf.int64`.
1366
1367    Returns:
1368      A copy of this RaggedTensor, with the `row_splits` cast to the given
1369      type.
1370    """
1371    dtype = dtypes.as_dtype(dtype)
1372    if dtype not in (dtypes.int32, dtypes.int64):
1373      raise ValueError("dtype must be int32 or int64")
1374    if self._row_partition.dtype == dtype:
1375      return self
1376    current_values = self._values
1377    if isinstance(current_values, RaggedTensor):
1378      return RaggedTensor(
1379          values=current_values.with_row_splits_dtype(dtype),
1380          row_partition=self._row_partition.with_row_splits_dtype(dtype),
1381          internal=True)
1382    else:
1383      return RaggedTensor(
1384          values=current_values,
1385          row_partition=self._row_partition.with_row_splits_dtype(dtype),
1386          internal=True)
1387
1388  def merge_dims(self, outer_axis, inner_axis):
1389    """Merges outer_axis...inner_axis into a single dimension.
1390
1391    Returns a copy of this RaggedTensor with the specified range of dimensions
1392    flattened into a single dimension, with elements in row-major order.
1393
1394    #### Examples:
1395
1396    >>> rt = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]]])
1397    >>> print(rt.merge_dims(0, 1))
1398    <tf.RaggedTensor [[1, 2], [3], [4, 5, 6]]>
1399    >>> print(rt.merge_dims(1, 2))
1400    <tf.RaggedTensor [[1, 2, 3], [4, 5, 6]]>
1401    >>> print(rt.merge_dims(0, 2))
1402    tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32)
1403
1404    To mimic the behavior of `np.flatten` (which flattens all dimensions), use
1405    `rt.merge_dims(0, -1).  To mimic the behavior of `tf.layers.Flatten` (which
1406    flattens all dimensions except the outermost batch dimension), use
1407    `rt.merge_dims(1, -1)`.
1408
1409    Args:
1410      outer_axis: `int`: The first dimension in the range of dimensions to
1411        merge. May be negative if `self.shape.rank` is statically known.
1412      inner_axis: `int`: The last dimension in the range of dimensions to merge.
1413        May be negative if `self.shape.rank` is statically known.
1414
1415    Returns:
1416      A copy of this tensor, with the specified dimensions merged into a
1417      single dimension.  The shape of the returned tensor will be
1418      `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
1419      is the total number of slices in the merged dimensions.
1420    """
1421    outer_axis = array_ops.get_positive_axis(
1422        outer_axis,
1423        self.shape.rank,
1424        axis_name="outer_axis",
1425        ndims_name="rank(self)")
1426    inner_axis = array_ops.get_positive_axis(
1427        inner_axis,
1428        self.shape.rank,
1429        axis_name="inner_axis",
1430        ndims_name="rank(self)")
1431    if not outer_axis <= inner_axis:
1432      raise ValueError("Expected outer_axis (%d) to be less than or equal to "
1433                       "inner_axis (%d)" % (outer_axis, inner_axis))
1434    return merge_dims(self, outer_axis, inner_axis)
1435
  def _set_shape(self, shape):
    """Updates the static shape of `self` to be `shape`.

    * If a dimension of `shape` has known rank, and is encoded via
      partitioning, then this will update the corresponding partition to
      define `_uniform_row_length` and `nrows`.
    * If a dimension of `shape` has a known rank, and is encoded as one
      of the `flat_values` dimensions, then `flat_values.set_shape()` will
      be used to update its shape.

    Warning: Using this method to assert an incorrect shape for a RaggedTensor
    (i.e., one that's not consistent with its actual shape) can cause
    segmentation faults and very difficult-to-diagnose behavior.  Only use this
    method if you are certain that the shape is correct.

    Args:
      shape: `tf.TensorShape` specifying the shape for this `RaggedTensor`.

    Raises:
      ValueError: If a known dimension of `shape` conflicts with a statically
        known `_uniform_row_length` of one of the nested row partitions.
    """
    # TODO(edloper): Refactor this to not directly access private members
    # of RowPartition.
    # pylint: disable=protected-access

    shape = tensor_shape.as_shape(shape)
    if shape.rank is None:
      return  # Nothing to do.

    shape = shape.as_list()

    # Outermost dimension
    # A tensor with `nrows` rows has `nrows + 1` splits, so pinning the
    # outermost dimension pins the splits' static length.
    # NOTE(review): `set_shape` is given the scalar `shape[0] + 1` here;
    # presumably it is interpreted as the 1-D shape `[shape[0] + 1]` — confirm.
    if shape[0] is not None:
      self._row_partition._row_splits.set_shape(shape[0] + 1)

    # Partitioned dimensions
    # For each ragged dimension, `shape[i + 1]` is the (possibly None) static
    # size of the dimension partitioned by the `i`th nested row partition.
    dtype = self._row_partition.dtype
    for i, partition in enumerate(self._nested_row_partitions):
      size = shape[i + 1]
      if size is not None:
        if partition._uniform_row_length is not None:
          # The partition already claims a uniform row length; check it
          # against the requested static size rather than overwriting it.
          old_row_length = tensor_util.constant_value(
              partition._uniform_row_length)
          if old_row_length is not None:
            if size == old_row_length:
              continue  # already have shape info for this axis.
            else:
              raise ValueError("Inconsistent size for axis %s: %s vs %s" %
                               ((i + 1), old_row_length, size))
        # Record the newly-known uniform row length, and derive `_nrows` from
        # the splits length if it was not already cached.
        partition._uniform_row_length = ops.convert_to_tensor(size, dtype)
        if partition._nrows is None:
          partition._nrows = array_ops.size(partition._row_splits) - 1

    # Inner dimensions
    # Dimensions past the ragged ones live on `flat_values`; its leading
    # (flattened) dimension is left unknown (None).
    flat_shape = tensor_shape.as_shape([None] + shape[self.ragged_rank + 1:])
    self.flat_values.set_shape(flat_shape)
1489
1490#=============================================================================
1491# Tensor Type Conversions
1492#=============================================================================
1493
  @classmethod
  @dispatch.add_dispatch_support
  def from_tensor(cls,
                  tensor,
                  lengths=None,
                  padding=None,
                  ragged_rank=1,
                  name=None,
                  row_splits_dtype=dtypes.int64):
    """Converts a `tf.Tensor` into a `RaggedTensor`.

    The set of absent/default values may be specified using a vector of lengths
    or a padding value (but not both).  If `lengths` is specified, then the
    output tensor will satisfy `output[row] = tensor[row][:lengths[row]]`. If
    `lengths` is a list of lists or tuple of lists, those lists will be used
    as nested row lengths. If `padding` is specified, then any row *suffix*
    consisting entirely of `padding` will be excluded from the returned
    `RaggedTensor`.  If neither `lengths` nor `padding` is specified, then the
    returned `RaggedTensor` will have no absent/default values.

    Examples:

    >>> dt = tf.constant([[5, 7, 0], [0, 3, 0], [6, 0, 0]])
    >>> tf.RaggedTensor.from_tensor(dt)
    <tf.RaggedTensor [[5, 7, 0], [0, 3, 0], [6, 0, 0]]>
    >>> tf.RaggedTensor.from_tensor(dt, lengths=[1, 0, 3])
    <tf.RaggedTensor [[5], [], [6, 0, 0]]>

    >>> tf.RaggedTensor.from_tensor(dt, padding=0)
    <tf.RaggedTensor [[5, 7], [0, 3], [6]]>

    >>> dt = tf.constant([[[5, 0], [7, 0], [0, 0]],
    ...                   [[0, 0], [3, 0], [0, 0]],
    ...                   [[6, 0], [0, 0], [0, 0]]])
    >>> tf.RaggedTensor.from_tensor(dt, lengths=([2, 0, 3], [1, 1, 2, 0, 1]))
    <tf.RaggedTensor [[[5], [7]], [], [[6, 0], [], [0]]]>

    Args:
      tensor: The `Tensor` to convert.  Must have rank `ragged_rank + 1` or
        higher.
      lengths: An optional set of row lengths, specified using a 1-D integer
        `Tensor` whose length is equal to `tensor.shape[0]` (the number of rows
        in `tensor`).  If specified, then `output[row]` will contain
        `tensor[row][:lengths[row]]`.  Negative lengths are treated as zero. You
        may optionally pass a list or tuple of lengths to this argument, which
        will be used as nested row lengths to construct a ragged tensor with
        multiple ragged dimensions.
      padding: An optional padding value.  If specified, then any row suffix
        consisting entirely of `padding` will be excluded from the returned
        RaggedTensor.  `padding` is a `Tensor` with the same dtype as `tensor`
        and with `shape=tensor.shape[ragged_rank + 1:]`.
      ragged_rank: Integer specifying the ragged rank for the returned
        `RaggedTensor`.  Must be greater than zero.
      name: A name prefix for the returned tensors (optional).
      row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
        tensor.  One of `tf.int32` or `tf.int64`.

    Returns:
      A `RaggedTensor` with the specified `ragged_rank`.  The shape of the
      returned ragged tensor is compatible with the shape of `tensor`.
    Raises:
      ValueError: If both `lengths` and `padding` are specified.
    """
    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if lengths is not None and padding is not None:
      raise ValueError("Specify lengths or padding, but not both")
    if not isinstance(ragged_rank, int):
      raise TypeError("ragged_rank expected int, got %r" % ragged_rank)
    if ragged_rank <= 0:
      raise ValueError("ragged_rank must be greater than 0; got %s" %
                       ragged_rank)

    with ops.name_scope(name, "RaggedFromTensor", [tensor, lengths, padding]):
      tensor = ops.convert_to_tensor(tensor, name="tensor")
      tensor.shape.with_rank_at_least(ragged_rank + 1)
      input_shape = array_ops.shape(tensor, out_type=row_splits_dtype)
      ncols = input_shape[1]

      # Handle nested row lengths.
      if (lengths is not None and isinstance(lengths, (list, tuple)) and
          len(lengths) and not isinstance(lengths[0], (int, float))):
        if ragged_rank not in (1, len(lengths)):
          # Note: we accept `ragged_rank=1` here because it's the default value;
          # i.e., if the user passes in a tuple of lengths, but doesn't specify
          # ragged_rank, then we should use that tuple to determine ragged_rank.
          # We only want to complain if they pass in an explicit ragged_rank
          # that doesn't match len(lengths).
          raise ValueError("If lengths is a tuple of row_lengths, then "
                           "ragged_rank must be len(lengths).")
        # Rather than reconstructing the tensor mask directly, we can
        # recreate it as a boolean RaggedTensor, then densify that and use
        # that as the mask to clear out the unused data in the passed tensor.
        tensor.shape.with_rank_at_least(len(lengths) + 1)
        num_tokens = math_ops.reduce_sum(lengths[-1])
        ones_mask = array_ops.ones([num_tokens], dtype=dtypes.bool)
        ragged_mask = cls.from_nested_row_lengths(
            ones_mask, lengths, validate=False)
        dense_ragged_mask = ragged_mask.to_tensor(default_value=False)
        masked_data = array_ops.boolean_mask(tensor, dense_ragged_mask)
        return cls.from_nested_row_lengths(masked_data, lengths, validate=False)

      # Handle ragged_rank>1 via recursion:
      # If the output should have multiple ragged dimensions, then first
      # flatten the tensor to eliminate all but the last ragged dimension,
      # and recursively convert that flattened tensor.  Then add on the splits
      # for the dimensions that we flattened out.
      if ragged_rank > 1:
        if tensor.shape.is_fully_defined():
          input_shape = tensor.shape.as_list()
          # The total number of elements in each dimension.  E.g., if
          # input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total.
          dim_size = np.cumprod(input_shape)
          new_shape = [dim_size[ragged_rank - 1]] + input_shape[ragged_rank:]
        else:
          dim_size = math_ops.cumprod(input_shape)
          new_shape = array_ops.concat([[dim_size[ragged_rank - 1]],
                                        input_shape[ragged_rank:]],
                                       axis=0)
        flattened = array_ops.reshape(tensor, new_shape)
        result = cls.from_tensor(
            flattened, lengths, padding, row_splits_dtype=row_splits_dtype)

        # Wrap the flattened result back up with one uniform partition per
        # flattened-out dimension, innermost first.
        for axis in range(ragged_rank - 1, 0, -1):
          dim_len = tensor_shape.dimension_at_index(tensor.shape, axis).value
          if dim_len is None:
            dim_len = input_shape[axis]
          else:
            dim_len = constant_op.constant(dim_len, row_splits_dtype)
          result = RaggedTensor.from_uniform_row_length(
              values=result,
              uniform_row_length=dim_len,
              nrows=dim_size[axis - 1],
              validate=False)
        return result

      # If padding was specified, then use it to find row lengths.
      if padding is not None:
        padding = ops.convert_to_tensor(
            padding, name="padding", dtype=tensor.dtype)
        padding.shape.assert_is_compatible_with(tensor.shape[2:])

        # Find places where the padding is equal to the tensor.  (This will
        # broadcast `padding` across the outermost 2 dimensions of `tensor`,
        # so `has_default_value.shape = tensor.shape`.)
        has_default_value = math_ops.equal(padding, tensor)

        # If the padding isn't a scalar, then require that all values in the
        # padding match each item in the tensor.  After this block of code,
        # `has_default.shape = tensor.shape[:2]`.  (Unfortunately, we can't just
        # use reduce_all for both cases, because when you pass an empty `axis`
        # list to reduce_all, it reduces all axes; but we want it to reduce no
        # axes -- i.e., to be a no-op.)
        tensor_rank = array_ops.rank(tensor)
        reduce_axis = math_ops.range(2, tensor_rank)
        has_default = control_flow_ops.cond(
            tensor_rank > 2,
            lambda: math_ops.reduce_all(has_default_value, axis=reduce_axis),
            lambda: has_default_value)
        has_default.set_shape(tensor_shape.TensorShape([None, None]))
        has_default.set_shape(tensor.shape[:2])

        # Use has_default to find the length of each row: for each
        # non-default item in a row, calculate the length that the row needs to
        # have to include that item; and then take the max of those values
        # (across each row).
        has_nondefault = math_ops.logical_not(has_default)
        has_nondefault = math_ops.cast(has_nondefault, row_splits_dtype)
        length_for_nondefault_value = (
            has_nondefault *
            array_ops.expand_dims(math_ops.range(1, ncols + 1), 0))
        lengths = math_ops.reduce_max(length_for_nondefault_value, axis=1)

      if lengths is not None:
        # If we have lengths (either directly supplied, or computed from
        # paddings), then use those to construct splits; and then use masking
        # to get the corresponding values.
        lengths = ragged_util.convert_to_int_tensor(lengths, "lengths",
                                                    row_splits_dtype)
        lengths.shape.assert_has_rank(1)
        # Clamp lengths into the valid range [0, ncols].
        lengths = math_ops.minimum(lengths, ncols)
        lengths = math_ops.maximum(lengths, 0)
        limits = math_ops.cumsum(lengths)
        splits = array_ops.concat(
            [array_ops.zeros([1], row_splits_dtype), limits], axis=0)
        mask = array_ops.sequence_mask(lengths, maxlen=ncols)
        values = array_ops.boolean_mask(tensor, mask)
        return cls.from_row_splits(values, splits, validate=False)

      # If neither padding nor lengths were specified, then create a splits
      # vector that contains no default values, and reshape the input tensor
      # to form the values for the RaggedTensor.
      values_shape = array_ops.concat([[input_shape[0] * input_shape[1]],
                                       input_shape[2:]], axis=0)
      values = array_ops.reshape(tensor, values_shape)
      const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value
      const_ncols = tensor_shape.dimension_at_index(tensor.shape, 1).value
      # Prefer statically known dimension sizes when available, so the result
      # carries as much static shape information as possible.
      if const_nrows is not None:
        nrows = constant_op.constant(const_nrows, row_splits_dtype)
      else:
        nrows = input_shape[0]
      if const_ncols is not None:
        ncols = constant_op.constant(const_ncols, row_splits_dtype)
      else:
        ncols = input_shape[1]
      return RaggedTensor.from_uniform_row_length(
          values=values, uniform_row_length=ncols, nrows=nrows, validate=False)
1700
  def to_tensor(self, default_value=None, name=None, shape=None):
    """Converts this `RaggedTensor` into a `tf.Tensor`.

    If `shape` is specified, then the result is padded and/or truncated to
    the specified shape.

    Examples:

    >>> rt = tf.ragged.constant([[9, 8, 7], [], [6, 5], [4]])
    >>> print(rt.to_tensor())
    tf.Tensor(
        [[9 8 7] [0 0 0] [6 5 0] [4 0 0]], shape=(4, 3), dtype=int32)
    >>> print(rt.to_tensor(shape=[5, 2]))
    tf.Tensor(
        [[9 8] [0 0] [6 5] [4 0] [0 0]], shape=(5, 2), dtype=int32)

    Args:
      default_value: Value to set for indices not specified in `self`. Defaults
        to zero.  `default_value` must be broadcastable to
        `self.shape[self.ragged_rank + 1:]`.
      name: A name prefix for the returned tensors (optional).
      shape: The shape of the resulting dense tensor.  In particular,
        `result.shape[i]` is `shape[i]` (if `shape[i]` is not None), or
        `self.bounding_shape(i)` (otherwise).`shape.rank` must be `None` or
        equal to `self.rank`.

    Returns:
      A `Tensor` with shape `ragged.bounding_shape(self)` and the
      values specified by the non-empty values in `self`.  Empty values are
      assigned `default_value`.
    """
    with ops.name_scope(name, "RaggedToTensor", [self, default_value, shape]):
      if default_value is not None:
        default_value = ops.convert_to_tensor(
            default_value, name="default_value", dtype=self.dtype)
      type_tensor_pairs = _get_row_partition_type_tensor_pairs(self)
      row_partition_types = [x[0] for x in type_tensor_pairs]
      row_partition_tensors = [x[1] for x in type_tensor_pairs]
      if default_value is None:
        default_value = array_ops.zeros((), self.dtype)

      # If `shape` is a list/tuple that mixes Python ints with Tensors, pack
      # it into a single Tensor so it can be passed to the conversion op.
      if (isinstance(shape, (list, tuple)) and
          any(isinstance(v, ops.Tensor) for v in shape) and
          all(isinstance(v, (int, ops.Tensor)) for v in shape)):
        shape = array_ops.stack(shape)

      shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype)
      tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor(
          shape=shape_tensor,
          values=self.flat_values,
          default_value=default_value,
          row_partition_types=row_partition_types,
          row_partition_tensors=row_partition_tensors)

      ragged_shape = self.shape

      if ragged_shape.rank is not None and not isinstance(shape, ops.Tensor):
        # Merged self.shape and shape, favoring the second one as it takes
        # into account potential padding added to the output.
        shape = tensor_shape.as_shape(shape)
        if shape.rank is None:
          output_shape = ragged_shape
        else:
          # At this point we can assume that shape.rank == ragged_shape.rank
          # because otherwise it would have failed earlier.
          output_shape = [s1 if s1 is not None else s2 for (s1, s2)
                          in zip(shape.as_list(), ragged_shape.as_list())]
        tensor.set_shape(output_shape)

      return tensor
1771
  @classmethod
  @dispatch.add_dispatch_support
  def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
    """Converts a 2D `tf.sparse.SparseTensor` to a `RaggedTensor`.

    Each row of the `output` `RaggedTensor` will contain the explicit values
    from the same row in `st_input`.  `st_input` must be ragged-right.  If it
    is not ragged-right, then an error will be generated.

    Example:

    >>> indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0]]
    >>> st = tf.sparse.SparseTensor(indices=indices,
    ...                             values=[1, 2, 3, 4, 5],
    ...                             dense_shape=[4, 3])
    >>> tf.RaggedTensor.from_sparse(st).to_list()
    [[1, 2, 3], [4], [], [5]]

    Currently, only two-dimensional `SparseTensors` are supported.

    Args:
      st_input: The sparse tensor to convert.  Must have rank 2.
      name: A name prefix for the returned tensors (optional).
      row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
        tensor.  One of `tf.int32` or `tf.int64`.

    Returns:
      A `RaggedTensor` with the same values as `st_input`.
      `output.ragged_rank = rank(st_input) - 1`.
      `output.shape = [st_input.dense_shape[0], None]`.
    Raises:
      ValueError: If the number of dimensions in `st_input` is not known
        statically, or is not two.
    """
    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if not sparse_tensor.is_sparse(st_input):
      raise TypeError("Expected SparseTensor, got %s" % type(st_input).__name__)
    with ops.name_scope(name, "RaggedFromSparse", [st_input]):
      st_input = sparse_tensor.convert_to_tensor_or_sparse_tensor(
          st_input, name="st_input")

      # Try to infer the static rank from either the dense_shape vector's
      # length or the indices matrix's second dimension.
      if st_input.dense_shape.shape.ndims is None:
        static_rank_from_dense_shape = None
      else:
        static_rank_from_dense_shape = st_input.dense_shape.shape.dims[0].value

      if st_input.indices.shape.ndims is None:
        static_rank_from_indices = None
      else:
        static_rank_from_indices = st_input.indices.shape.dims[1].value

      if static_rank_from_dense_shape != 2 and static_rank_from_indices != 2:
        raise ValueError("rank(st_input) must be 2")

      with ops.control_dependencies(
          _assert_sparse_indices_are_ragged_right(st_input.indices)):
        # Treat sparse row indices as segment ids to generate a splits tensor
        # that we can pair with the sparse tensor values.  (Ignore sparse column
        # indices.)
        segment_ids = math_ops.cast(st_input.indices[:, 0], row_splits_dtype)
        num_segments = math_ops.cast(st_input.dense_shape[0], row_splits_dtype)
        return cls.from_value_rowids(
            st_input.values, segment_ids, num_segments, validate=False)
1835
1836  def to_sparse(self, name=None):
1837    """Converts this `RaggedTensor` into a `tf.sparse.SparseTensor`.
1838
1839    Example:
1840
1841    >>> rt = tf.ragged.constant([[1, 2, 3], [4], [], [5, 6]])
1842    >>> print(rt.to_sparse())
1843    SparseTensor(indices=tf.Tensor(
1844                     [[0 0] [0 1] [0 2] [1 0] [3 0] [3 1]],
1845                     shape=(6, 2), dtype=int64),
1846                 values=tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32),
1847                 dense_shape=tf.Tensor([4 3], shape=(2,), dtype=int64))
1848
1849    Args:
1850      name: A name prefix for the returned tensors (optional).
1851
1852    Returns:
1853      A SparseTensor with the same values as `self`.
1854    """
1855    with ops.name_scope(name, "RaggedToSparse", [self]):
1856      result = gen_ragged_conversion_ops.ragged_tensor_to_sparse(
1857          self.nested_row_splits, self.flat_values, name=name)
1858      return sparse_tensor.SparseTensor(result.sparse_indices,
1859                                        result.sparse_values,
1860                                        result.sparse_dense_shape)
1861
1862  @classmethod
1863  def _from_variant(cls,
1864                    variant,
1865                    dtype,
1866                    output_ragged_rank,
1867                    input_ragged_rank=None,
1868                    row_splits_dtype=dtypes.int64,
1869                    name=None):
1870    """Converts a `variant` Tensor into a `RaggedTensor`.
1871
1872    The input `variant` could be a scalar, meaning it encodes a single
1873    `RaggedTensor` with ragged_rank `output_ragged_rank`. Alternatively it could
1874    have an arbitrary rank, in which case each element is decoded into a
1875    `RaggedTensor` with ragged_rank `input_ragged_rank` and these are then
1876    stacked according to the input shape to output a single `RaggedTensor`
1877    with ragged_rank `output_ragged_rank`. If `input_ragged_rank` is not
1878    provided, it is inferred dynamically as `output_ragged_rank` -
1879    `rank(variant)`. If `input_ragged_rank` is provided, the following must be
1880    true: `output_ragged_rank` = `input_ragged_rank` + `rank(variant)`.
1881
1882    Example:
1883
1884    >>> rt = tf.ragged.constant([[0], [1, 2]])
1885    >>> et = rt._to_variant()
1886    >>> stacked_et = tf.stack([et, et])
1887    >>> tf.RaggedTensor._from_variant(  # scalar input.
1888    ...     et, dtype=tf.int32, output_ragged_rank=1).to_list()
1889    [[0], [1, 2]]
1890    >>> tf.RaggedTensor._from_variant(  # batched input.
1891    ...     stacked_et, dtype=tf.int32, output_ragged_rank=2).to_list()
1892    [[[0], [1, 2]], [[0], [1, 2]]]
1893
1894    Args:
1895      variant: A `variant` Tensor representing an encoded (possibly
1896        nested-batched) `RaggedTensor`.
1897      dtype: The dtype of the encoded `RaggedTensor`.
1898      output_ragged_rank: The expected ragged rank of the output `RaggedTensor`.
1899      input_ragged_rank: The ragged rank of each encoded `RaggedTensor`. This is
1900        optional and inferred dynamically if not provided.
1901      row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One
1902        of `tf.int32` or `tf.int64`.
1903      name: A name prefix for the returned tensors (optional).
1904
1905    Returns:
1906      A `RaggedTensor` of dtype `dtype` and ragged rank `output_ragged_rank`.
1907
1908    Raises:
1909      ValueError: If the input rank is known, `input_ragged_rank` is provided
1910          and `output_ragged_rank` = `input_ragged_rank` + `rank(variant)` does
1911          not hold.
1912    """
1913    variant = ops.convert_to_tensor(
1914        variant, name="variant", dtype=dtypes.variant)
1915    if (variant.shape.ndims is not None and input_ragged_rank is not None and
1916        output_ragged_rank != input_ragged_rank + variant.shape.ndims):
1917      raise ValueError(
1918          "output_ragged_rank must be equal to input_ragged_rank +"
1919          "variant.shape.ndims, found variant.shape.ndims: %d, "
1920          "input_ragged_rank: %d, output_ragged_rank: %d" %
1921          (variant.shape.ndims, input_ragged_rank, output_ragged_rank))
1922    input_ragged_rank = -1 if input_ragged_rank is None else input_ragged_rank
1923    with ops.name_scope(
1924        name, "RaggedFromVariant",
1925        [variant, dtype, input_ragged_rank, output_ragged_rank]):
1926      result = gen_ragged_conversion_ops.ragged_tensor_from_variant(
1927          variant, input_ragged_rank, output_ragged_rank, dtype,
1928          row_splits_dtype, name)
1929      return cls.from_nested_row_splits(
1930          result.output_dense_values,
1931          result.output_nested_splits,
1932          validate=False)
1933
1934  def _to_variant(self, batched_input=False, name=None):
1935    """Converts this `RaggedTensor` into a `variant` Tensor.
1936
1937    If `batched_input` is `True`, then the `RaggedTensor` is unbatched along the
1938    zero-th dimension, each component `RaggedTensor` is encoded into a scalar
1939    `variant` Tensor, and these are stacked to return a 1-D `variant` Tensor.
1940    If `batched_input` is `False`, then the `RaggedTensor` is encoded as is and
1941    a scalar `variant` Tensor is returned.
1942
1943    Example:
1944    >>> rt = tf.ragged.constant([[[0]], [[1]], [[2]]])
1945    >>> rt._to_variant().shape.as_list()
1946    []
1947    >>> rt._to_variant(batched_input=True).shape.as_list()
1948    [3]
1949
1950    Args:
1951      batched_input: If `True`, the `RaggedTensor` is unbatched and converted to
1952        a `variant` vector. Set to `False` by default.
1953      name: A name prefix for the returned tensors (optional).
1954
1955    Returns:
1956      A `variant` Tensor that encodes this `RaggedTensor`.
1957    """
1958    with ops.name_scope(name, "RaggedToVariant", [self, batched_input]):
1959      return gen_ragged_conversion_ops.ragged_tensor_to_variant(
1960          self.nested_row_splits, self.flat_values, batched_input, name)
1961
1962  #=============================================================================
1963  # String Encoding
1964  #=============================================================================
1965  def __repr__(self):
1966    if self._is_eager():
1967      return "<tf.RaggedTensor %s>" % self.to_list()
1968    else:
1969      return "tf.RaggedTensor(values=%s, row_splits=%s)" % (
1970          self.values, self.row_splits)
1971
1972  #=============================================================================
1973  # Eager Execution Mode
1974  #=============================================================================
1975
1976  def numpy(self):
1977    """Returns a numpy `array` with the values for this `RaggedTensor`.
1978
1979    Requires that this `RaggedTensor` was constructed in eager execution mode.
1980
1981    Ragged dimensions are encoded using numpy `arrays` with `dtype=object` and
1982    `rank=1`, where each element is a single row.
1983
1984    #### Examples
1985
1986    In the following example, the value returned by `RaggedTensor.numpy()`
1987    contains three numpy `array` objects: one for each row (with `rank=1` and
1988    `dtype=int64`), and one to combine them (with `rank=1` and `dtype=object`):
1989
1990    >>> tf.ragged.constant([[1, 2, 3], [4, 5]], dtype=tf.int64).numpy()
1991    array([array([1, 2, 3]), array([4, 5])], dtype=object)
1992
1993    Uniform dimensions are encoded using multidimensional numpy `array`s.  In
1994    the following example, the value returned by `RaggedTensor.numpy()` contains
1995    a single numpy `array` object, with `rank=2` and `dtype=int64`:
1996
1997    >>> tf.ragged.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64).numpy()
1998    array([[1, 2, 3], [4, 5, 6]])
1999
2000    Returns:
2001      A numpy `array`.
2002    """
2003    if not self._is_eager():
2004      raise ValueError("RaggedTensor.numpy() is only supported in eager mode.")
2005    values = self.values.numpy()
2006    splits = self.row_splits.numpy()
2007    rows = [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)]
2008    if not rows:
2009      return np.zeros((0, 0) + values.shape[1:], dtype=values.dtype)
2010    # Note: if `rows` have ragged lengths, then they will be stored in a
2011    # np.ndarray with dtype=object and rank=1.  If they have uniform lengths,
2012    # they will be combined into a single np.ndarray with dtype=row.dtype and
2013    # rank=row.rank+1.
2014    return np.array(rows)
2015
2016  def to_list(self):
2017    """Returns a nested Python `list` with the values for this `RaggedTensor`.
2018
2019    Requires that `rt` was constructed in eager execution mode.
2020
2021    Returns:
2022      A nested Python `list`.
2023    """
2024    if self._is_eager():
2025      return self._eager_value().to_list()
2026    else:
2027      raise ValueError("RaggedTensor.to_list() is only supported in eager "
2028                       "mode; in graph mode, evaluate the RaggedTensor first "
2029                       "and then use RaggedTensorValue.to_list().")
2030
2031  def _eager_value(self):
2032    """Returns a RaggedTensorValue for self.  Requires self._is_eager()=true."""
2033    value = self.flat_values.numpy()
2034    for row_splits in reversed(self.nested_row_splits):
2035      value = ragged_tensor_value.RaggedTensorValue(value, row_splits.numpy())
2036    return value
2037
2038  def _is_eager(self):
2039    """Returns True if values & row_splits Tensors are all `EagerTensor`s."""
2040    rt = self
2041    while isinstance(rt, RaggedTensor):
2042      if not isinstance(rt.row_splits, ops.EagerTensor):
2043        return False
2044      rt = rt.values
2045    return isinstance(rt, ops.EagerTensor)
2046
2047  #=============================================================================
2048  # Operators
2049  #=============================================================================
2050  # To avoid circular dependencies, we define stub methods for operators here,
2051  # and then override them when the ragged_operators module is imported.
2052
2053  def _overloaded_operator(name):  # pylint: disable=no-self-argument
2054    def stub(*args, **kwargs):
2055      del args, kwargs
2056      raise ValueError(
2057          "You must import 'tensorflow.python.ops.ragged.ragged_ops' "
2058          "before using RaggedTensor.%s" % name)
2059    return stub
2060
2061  __getitem__ = _overloaded_operator("__getitem__")
2062  __ge__ = _overloaded_operator("__ge__")
2063  __gt__ = _overloaded_operator("__gt__")
2064  __le__ = _overloaded_operator("__le__")
2065  __lt__ = _overloaded_operator("__lt__")
2066  __and__ = _overloaded_operator("__and__")
2067  __rand__ = _overloaded_operator("__rand__")
2068  __invert__ = _overloaded_operator("__invert__")
2069  __ror__ = _overloaded_operator("__ror__")
2070  __or__ = _overloaded_operator("__or__")
2071  __xor__ = _overloaded_operator("__xor__")
2072  __rxor__ = _overloaded_operator("__rxor__")
2073  __abs__ = _overloaded_operator("__abs__")
2074  __add__ = _overloaded_operator("__add__")
2075  __radd__ = _overloaded_operator("__radd__")
2076  __div__ = _overloaded_operator("__div__")
2077  __rdiv__ = _overloaded_operator("__rdiv__")
2078  __floordiv__ = _overloaded_operator("__floordiv__")
2079  __rfloordiv__ = _overloaded_operator("__rfloordiv__")
2080  __mod__ = _overloaded_operator("__mod__")
2081  __rmod__ = _overloaded_operator("__rmod__")
2082  __mul__ = _overloaded_operator("__mul__")
2083  __rmul__ = _overloaded_operator("__rmul__")
2084  __neg__ = _overloaded_operator("__neg__")
2085  __pow__ = _overloaded_operator("__pow__")
2086  __rpow__ = _overloaded_operator("__rpow__")
2087  __sub__ = _overloaded_operator("__sub__")
2088  __rsub__ = _overloaded_operator("__rsub__")
2089  __truediv__ = _overloaded_operator("__truediv__")
2090  __rtruediv__ = _overloaded_operator("__rtruediv__")
2091  del _overloaded_operator
2092
2093  #=============================================================================
2094  # Name Scope
2095  #=============================================================================
2096
2097  # This private function is used by ops.name_scope to ensure that all of the
2098  # input tensors for the scope belong to the same graph.  Defining this means
2099  # that you may include `RaggedTensor` objects in the name_scope `values`
2100  # list.
2101  def _as_graph_element(self):
2102    """Convert `self` to a graph element."""
2103    values = self.values
2104    while isinstance(values, RaggedTensor):
2105      values = values.values
2106    return values
2107
2108  #=============================================================================
2109  # Composite Tensor
2110  #=============================================================================
2111
  @property
  def _type_spec(self):
    """A `RaggedTensorSpec` describing this tensor (CompositeTensor API)."""
    return RaggedTensorSpec.from_value(self)
2115
  def _shape_invariant_to_type_spec(self, shape):
    """Returns a `RaggedTensorSpec` for `self` with the given shape invariant."""
    return RaggedTensorSpec(shape, self.dtype, self.ragged_rank,
                            self.row_splits.dtype)
2119
  def consumers(self):
    """Returns the ops that consume this tensor's components."""
    return self._consumers()
2122
2123
def is_ragged(value):
  """Returns true if `value` is a ragged tensor or ragged tensor value."""
  ragged_types = (RaggedTensor, ragged_tensor_value.RaggedTensorValue)
  return isinstance(value, ragged_types)
2128
2129
def match_row_splits_dtypes(*tensors, **kwargs):
  """Return a copy of `tensors` with row_splits all having the same dtype.

  Args:
    *tensors: A list of Tensors or RaggedTensors.
    **kwargs: If 'return_dtype=True', then return a tuple (dtype, tensors),
      where `dtype` is the data type used by row-splits, and `tensors` is the
      converted list of `Tensors` and `RaggedTensors`.

  Returns:
    The converted list of `Tensors` and `RaggedTensors`.
  """
  return_dtype = kwargs.pop("return_dtype", False)
  if kwargs:
    raise ValueError("Unexpected keyword args %r" % kwargs)

  # Collect the row_splits dtypes of the ragged inputs (plain Tensors have no
  # row_splits and are ignored here).
  splits_dtypes = [t.row_splits.dtype
                   for t in tensors
                   if isinstance(t, RaggedTensor)]
  has_int32 = any(d == dtypes.int32 for d in splits_dtypes)
  has_int64 = any(d != dtypes.int32 for d in splits_dtypes)

  if has_int32 and has_int64:
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; "
                       "use RaggedTensor.with_row_splits_dtype() to convert "
                       "them to compatible dtypes.")
    # Auto-casting is enabled: upcast every ragged input to int64.
    dtype = dtypes.int64
    tensors = tuple(
        t.with_row_splits_dtype(dtypes.int64)
        if isinstance(t, RaggedTensor) else t
        for t in tensors)
  elif has_int32:
    dtype = dtypes.int32
  else:
    dtype = dtypes.int64

  return (dtype, tensors) if return_dtype else tensors
2175
2176
2177#===============================================================================
2178# RaggedTensorSpec
2179#===============================================================================
@tf_export("RaggedTensorSpec")
@type_spec.register("tf.RaggedTensorSpec")
class RaggedTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for a `tf.RaggedTensor`."""

  __slots__ = [
      "_shape", "_dtype", "_ragged_rank", "_row_splits_dtype",
      "_flat_values_spec"
  ]

  @property
  def dtype(self):
    """The `tf.dtypes.DType` specified by this type for the RaggedTensor.

    Examples:

    >>> rt = tf.ragged.constant([["a"], ["b", "c"]], dtype=tf.string)
    >>> tf.type_spec_from_value(rt).dtype
    tf.string

    Returns:
      A `tf.dtypes.DType` of the values in the RaggedTensor.
    """
    return self._dtype

  @property
  def shape(self):
    """The statically known shape of the RaggedTensor.

    Examples:

    >>> rt = tf.ragged.constant([[0], [1, 2]])
    >>> tf.type_spec_from_value(rt).shape
    TensorShape([2, None])

    >>> rt = tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1)
    >>> tf.type_spec_from_value(rt).shape
    TensorShape([2, None, 2])

    Returns:
      A `tf.TensorShape` containing the statically known shape of the
      RaggedTensor. Ragged dimensions have a size of `None`.
    """
    return self._shape

  @property
  def ragged_rank(self):
    """The number of times the RaggedTensor's flat_values is partitioned.

    Defaults to `shape.ndims - 1`.

    Examples:

    >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
    >>> tf.type_spec_from_value(values).ragged_rank
    1

    >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2)
    >>> tf.type_spec_from_value(rt1).ragged_rank
    2

    Returns:
      A Python `int` indicating the number of times the underlying `flat_values`
      Tensor has been partitioned to add a new dimension.
      I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`.
    """
    return self._ragged_rank

  @property
  def row_splits_dtype(self):
    """The `tf.dtypes.DType` of the RaggedTensor's `row_splits`.

    Examples:

    >>> rt = tf.ragged.constant([[1, 2, 3], [4]], row_splits_dtype=tf.int64)
    >>> tf.type_spec_from_value(rt).row_splits_dtype
    tf.int64

    Returns:
      A `tf.dtypes.DType` for the RaggedTensor's `row_splits` tensor. One
      of `tf.int32` or `tf.int64`.
    """
    return self._row_splits_dtype

  @property
  def flat_values_spec(self):
    """The `TypeSpec` of the flat_values of RaggedTensor.

    Returns:
      - The TypeSpec of flat_values.
      - None when the flat_values is a Tensor.
    """
    return self._flat_values_spec

  @property
  def value_type(self):
    # With ragged_rank 0 there are no ragged partitions, so values of this
    # spec are plain Tensors.
    return RaggedTensor if self._ragged_rank > 0 else ops.Tensor

  def __init__(self,
               shape=None,
               dtype=dtypes.float32,
               ragged_rank=None,
               row_splits_dtype=dtypes.int64,
               flat_values_spec=None):
    """Constructs a type specification for a `tf.RaggedTensor`.

    Args:
      shape: The shape of the RaggedTensor, or `None` to allow any shape.  If a
        shape is specified, then all ragged dimensions must have size `None`.
      dtype: `tf.DType` of values in the RaggedTensor.
      ragged_rank: Python integer, the number of times the RaggedTensor's
        flat_values is partitioned.  Defaults to `shape.ndims - 1`.
      row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One
        of `tf.int32` or `tf.int64`.
      flat_values_spec: TypeSpec for flat_value of the RaggedTensor. It shall be
        provided when the flat_values is a CompositeTensor rather then Tensor.
        If both `dtype` and `flat_values_spec` are provided, `dtype` must
        be the same as `flat_values_spec.dtype`. (experimental)

    Raises:
      ValueError: If neither `dtype` nor `flat_values_spec` is provided, if
        they disagree, if neither `ragged_rank` nor a ranked `shape` is
        given, or if `ragged_rank` is not less than the shape's rank.
      TypeError: If `ragged_rank` is not an int.
    """
    self._shape = tensor_shape.as_shape(shape)
    self._row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    # `dtype` may be derived from `flat_values_spec`, but when both are
    # given they must agree.
    if flat_values_spec is not None:
      if dtype is None:
        dtype = flat_values_spec.dtype
      elif dtype != flat_values_spec.dtype:
        raise ValueError("dtype must be the same as flat_values_spec.dtype")
    elif dtype is None:
      raise ValueError(
          "At least one of dtype or flat_values_spec must be provided")
    self._dtype = dtypes.as_dtype(dtype)
    self._flat_values_spec = flat_values_spec

    # If ragged_rank isn't given explicitly, infer it from the shape's rank.
    rank = self._shape.ndims
    if ragged_rank is None:
      if rank is None:
        raise ValueError("Must specify ragged_rank or "
                         "a shape with a known rank.")
      ragged_rank = rank - 1
    self._ragged_rank = ragged_rank
    if not isinstance(self._ragged_rank, int):
      raise TypeError("ragged_rank must be an int")

    if rank is not None:
      if ragged_rank >= rank:
        raise ValueError("ragged_rank must be less than rank.")

  def is_compatible_with(self, spec_or_value):
    # RaggedTensor with ragged_rank 0 can be compatible with raw flat_values.
    if self._ragged_rank == 0:
      if self._flat_values_spec is None:
        if isinstance(spec_or_value, (ops.Tensor, tensor_spec.TensorSpec)):
          return tensor_spec.TensorSpec(
              self._shape, self._dtype).is_compatible_with(spec_or_value)
      elif not isinstance(spec_or_value, (RaggedTensor, RaggedTensorSpec)):
        return self._flat_values_spec.is_compatible_with(spec_or_value)
    return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value)

  def _serialize(self):
    # flat_values_spec is only included in the serialization when it is set.
    if self._flat_values_spec is None:
      return (self._shape, self._dtype, self._ragged_rank,
              self._row_splits_dtype)
    else:
      return (self._shape, self._dtype, self._ragged_rank,
              self._row_splits_dtype, self._flat_values_spec)

  @property
  def _component_specs(self):
    if self._ragged_rank == 0:
      if self._flat_values_spec is not None:
        return [self._flat_values_spec]
      else:
        return [tensor_spec.TensorSpec(self._shape, self._dtype)]

    flat_values_spec = self._flat_values_spec
    if flat_values_spec is None:
      flat_values_shape = tensor_shape.TensorShape([None]).concatenate(
          self._shape[self._ragged_rank + 1:])
      flat_values_spec = tensor_spec.TensorSpec(flat_values_shape, self._dtype)
    # The outermost row_splits has length nrows+1 when nrows is statically
    # known; the inner splits' lengths are unknown.
    outer_dim = tensor_shape.dimension_at_index(self._shape, 0)
    outer_splits_shape = [None if outer_dim is None else outer_dim + 1]
    inner_splits_spec = tensor_spec.TensorSpec([None], self._row_splits_dtype)

    specs = ([
        flat_values_spec,
        tensor_spec.TensorSpec(outer_splits_shape, self._row_splits_dtype)
    ] + [inner_splits_spec for _ in range(self._ragged_rank - 1)])
    return specs

  def _to_components(self, value):
    # Components are the flat_values followed by the nested row_splits
    # (outermost first).
    if is_ragged(value):
      return [value.flat_values] + list(value.nested_row_splits)
    else:
      return [value]

  def _from_components(self, tensor_list):
    result = tensor_list[0]
    # When all components are numpy arrays and TF2 behavior is disabled,
    # rebuild a RaggedTensorValue; otherwise rebuild a RaggedTensor from the
    # innermost values outward.
    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
        not tf2.enabled()):
      for row_splits in reversed(tensor_list[1:]):
        result = ragged_tensor_value.RaggedTensorValue(result, row_splits)
    else:
      if isinstance(tensor_list[0], np.ndarray):
        tensor_list = [ops.convert_to_tensor(t) for t in tensor_list]
        result = tensor_list[0]
      for row_splits in reversed(tensor_list[1:]):
        result = RaggedTensor(
            result,
            RowPartition.from_row_splits(row_splits, validate=False),
            internal=True)
    return result

  # The RaggedTensorSpec tensor_list encoding uses to/from_variant ops
  # to (un)box the component tensors in a way that allows for batching &
  # unbatching.
  @property
  def _flat_tensor_specs(self):
    # NOTE(mishragaurav): The default flat shape of a boxed `RaggedTensor` is
    # `[]` (scalar), but a `RaggedTensorSpec` can also represent a batch of
    # boxed `RaggedTensor` objects with shape `(...)` (and batches of batches,
    # etc.), so the flat shape must be unknown.
    return [tensor_spec.TensorSpec(None, dtypes.variant)]

  def _to_tensor_list(self, value):
    # TODO(edloper): Update gen_ragged_conversion_ops that convert to and
    # from variant to include all of the row-partitioning tensors.
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported")
    ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
    if ragged_rank != self._ragged_rank:
      raise ValueError("Ragged rank of value (%d) does not match ragged "
                       "rank of type (%d)" % (ragged_rank, self._ragged_rank))
    if ragged_rank == 0:
      # A ragged_rank-0 value is a plain Tensor; box it with no partitions.
      return [
          gen_ragged_conversion_ops.ragged_tensor_to_variant(
              (), value, batched_input=False)
      ]
    # pylint: disable=protected-access
    return [value._to_variant(batched_input=False)]

  def _to_batched_tensor_list(self, value):
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported")
    ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
    if ragged_rank != self._ragged_rank:
      raise ValueError("Ragged rank of value (%d) does not match ragged "
                       "rank of type (%d)" % (ragged_rank, self._ragged_rank))
    if ragged_rank == 0:
      # TODO(b/141789000) Update this to handle ragged_rank=0.
      raise ValueError(
          "_to_batched_tensor_list doesn't support ragged_rank=0 yet")
    # pylint: disable=protected-access
    return [value._to_variant(batched_input=True)]

  def _from_compatible_tensor_list(self, tensor_list):
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported")
    if self._ragged_rank < 0:
      raise ValueError("ragged_rank must be non-negative; got %s." %
                       self._ragged_rank)
    result = RaggedTensor._from_variant(  # pylint: disable=protected-access
        tensor_list[0],
        dtype=self._dtype,
        row_splits_dtype=self._row_splits_dtype,
        output_ragged_rank=self._ragged_rank)
    if self._shape.ndims is not None:
      # Propagate any static shape information from the spec onto the result.
      if isinstance(result, RaggedTensor):
        outer_dim = tensor_shape.dimension_value(self._shape[0])
        if outer_dim is not None:
          result.row_splits.set_shape([outer_dim + 1])
        result._set_shape(self._shape)  # pylint: disable=protected-access
      else:
        result.set_shape(self._shape)
    return result

  def _batch(self, batch_size):
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported")
    # Batching prepends a dimension and adds one ragged partition.
    return RaggedTensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
        self._dtype, self._ragged_rank + 1, self._row_splits_dtype)

  def _unbatch(self):
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported")
    # Note: Negative ragged_rank is allowed here because the dataset could be
    # subsequently batched again. If ragged_rank > 1, assume row_splits_dtype is
    # consistent. Errors are handled in
    # RaggedTensorSpec._from_compatible_tensor_list()
    return RaggedTensorSpec(self._shape[1:], self._dtype, self._ragged_rank - 1,
                            self._row_splits_dtype)

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    return self._shape

  def _to_legacy_output_classes(self):
    return self

  @classmethod
  def from_value(cls, value):
    """Builds a `RaggedTensorSpec` describing `value`.

    Args:
      value: A `RaggedTensor` or `RaggedTensorValue`.

    Returns:
      A `RaggedTensorSpec`; `flat_values_spec` is only populated when the
      value's flat_values is not a plain `Tensor`.
    """
    if (isinstance(value, ragged_tensor_value.RaggedTensorValue) or
        isinstance(value.flat_values, ops.Tensor)):
      return cls(
          shape=value.shape,
          dtype=value.values.dtype,
          ragged_rank=value.ragged_rank,
          row_splits_dtype=value.row_splits.dtype)
    else:
      return cls(
          shape=value.shape,
          dtype=value.values.dtype,
          ragged_rank=value.ragged_rank,
          row_splits_dtype=value.row_splits.dtype,
          flat_values_spec=type_spec.type_spec_from_value(value.flat_values))
2496
2497
# Allow tf.type_spec_from_value() to build a RaggedTensorSpec from a
# RaggedTensorValue.
type_spec.register_type_spec_from_value_converter(
    ragged_tensor_value.RaggedTensorValue, RaggedTensorSpec.from_value)
2500
2501
2502#===============================================================================
2503# Convert value -> tensor
2504#===============================================================================
def convert_to_tensor_or_ragged_tensor(value,
                                       dtype=None,
                                       preferred_dtype=None,
                                       name=None):
  """Converts value to a `RaggedTensor` or `Tensor`.

  * If `value` is a `RaggedTensor`, then return it as-is.
  * If `value` is a `RaggedTensorValue`, return a corresponding constant
    `RaggedTensor`.
  * Otherwise, use `convert_to_tensor` to convert `value` to a `Tensor`.

  Args:
    value: A `RaggedTensor`, a `RaggedTensorValue`, or an object whose type has
      a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor.  If missing the type
      is inferred from the type of `value`.
    preferred_dtype: Optional element type for the returned tensor, used when
      dtype is None.  This argument has no effect if `value` is already a
      tensor, or when conversion is not possible.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or `RaggedTensor`.

  Raises:
    ValueError: If `value` is a `RaggedTensor` whose dtype is incompatible
      with the requested `dtype`.
  """
  # Normalize `dtype`, which callers may supply as e.g. a string ("int32") or
  # np.dtype.  Without this, the `dtype.is_compatible_with` call below would
  # fail with AttributeError for non-DType arguments when `value` is ragged.
  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
  if isinstance(value, RaggedTensor):
    if dtype and not dtype.is_compatible_with(value.dtype):
      raise ValueError("Tensor conversion requested dtype %s for "
                       "RaggedTensor with dtype %s: %r" %
                       (dtype.name, value.dtype.name, value))
    return value
  elif isinstance(value, ragged_tensor_value.RaggedTensorValue):
    with ops.name_scope(name, "ConvertToTensorOrRaggedTensor", []):
      flat_values = ops.convert_to_tensor(
          value=value.flat_values,
          dtype=dtype,
          preferred_dtype=preferred_dtype,
          name="flat_values")
      # Row splits were already validated when `value` was built.
      return RaggedTensor.from_nested_row_splits(
          flat_values, value.nested_row_splits, validate=False)
  else:
    return ops.convert_to_tensor_v2_with_dispatch(
        value=value, dtype=dtype, dtype_hint=preferred_dtype, name=name)
2547
2548
def _convert_to_ragged_tensor_values(value):
  """Converts value to supported RaggedTensor value.

  * If `value` is an object of supported value type, then return it as-is.
  * Otherwise convert it to Tensor or RaggedTensor.

  Args:
    value: An object of `Tensor`, `RaggedTensor` or registered RaggedTensor
      value types, or an object whose type has a registered `Tensor` conversion
      function.

  Returns:
    An object of `Tensor`, `RaggedTensor` or registered RaggedTensor value
    types.
  """
  if not _is_supported_ragged_values_type(value):
    value = convert_to_tensor_or_ragged_tensor(value, name="values")
  return value
2568
2569
2570#===============================================================================
2571# Register RaggedTensor for use with session.run.
2572#===============================================================================
2573def _ragged_tensor_value_from_components(components):
2574  components = list(components)
2575  value = components.pop()
2576  while components:
2577    value = ragged_tensor_value.RaggedTensorValue(value, components.pop())
2578  return value
2579
2580
def _ragged_tensor_session_fetch(rt):
  """Returns (components, rebuild_fn) for fetching `rt` with session.run."""
  flat_components = rt.nested_row_splits + (rt.flat_values,)
  return (flat_components, _ragged_tensor_value_from_components)
2584
2585
2586def _ragged_tensor_session_feed(feed_key, feed_val):
2587  key_components = feed_key.nested_row_splits + (feed_key.flat_values,)
2588  val_components = feed_val.nested_row_splits + (feed_val.flat_values,)
2589  return zip(key_components, val_components)
2590
2591
2592def _ragged_tensor_session_feed_for_partial_run(feed_key):
2593  return feed_key.nested_row_splits + (feed_key.flat_values,)
2594
2595
# Register fetch/feed converters so RaggedTensors can be passed directly to
# session.run() in graph mode.
session.register_session_run_conversion_functions(
    RaggedTensor, _ragged_tensor_session_fetch, _ragged_tensor_session_feed,
    _ragged_tensor_session_feed_for_partial_run)
2599
2600
2601#===============================================================================
2602# RaggedTensorType
2603#===============================================================================
class RaggedTensorType(object):
  """Encoding of a static type for a `RaggedTensor`.

  Use this type to express/declare that an output must have the type of
  `RaggedTensor`.
  """

  def __init__(self, dtype, ragged_rank, row_splits_dtype=dtypes.int64):
    """Initializes a RaggedTensorType object.

    Args:
      dtype: data type of the `RaggedTensor`'s inner values.
      ragged_rank: ragged_rank of the declared `RaggedTensor`.
      row_splits_dtype: data type for the `RaggedTensor`'s row splits.
        One of: `tf.int32` or `tf.int64`.
    """
    self._dtype = dtype
    self._ragged_rank = ragged_rank
    self._row_splits_dtype = dtypes.as_dtype(row_splits_dtype)

  @property
  def dtype(self):
    """The data type of the declared `RaggedTensor`'s inner values."""
    return self._dtype

  @property
  def ragged_rank(self):
    """The ragged_rank of the declared `RaggedTensor`."""
    return self._ragged_rank

  @property
  def row_splits_dtype(self):
    """The data type of the declared `RaggedTensor`'s row splits."""
    return self._row_splits_dtype

  def __repr__(self):
    return "RaggedTensorType(%r, %r, %r)" % (self.dtype, self.ragged_rank,
                                             self.row_splits_dtype)
2632
2633
2634#===============================================================================
2635# Helper Functions
2636#===============================================================================
def _assert_sparse_indices_are_ragged_right(indices):
  """Checks that the given SparseTensor.indices tensor is ragged-right.

  Example: `indices = [[0, 0], [0, 1], [2, 0], [3, 1]]` is not ragged right
  because the entry `[3, 1]` skips a cell.

  Args:
    indices: The SparseTensor indices to check.

  Returns:
    A list of control dependency op tensors.
  """
  prefix = indices[:, :-1]
  suffix = indices[:, -1]

  # An index starts a new row in the innermost dimension iff its prefix
  # differs from the previous index's prefix.  (This skips the first index;
  # it is checked separately below.)
  starts_new_row = math_ops.reduce_any(
      math_ops.not_equal(prefix[1:], prefix[:-1]), axis=1)

  # Row-starting indices must have a zero suffix; row-continuing indices must
  # have a suffix equal to their predecessor's suffix plus one.
  suffix_ok = array_ops.where(
      starts_new_row, math_ops.equal(suffix[1:], 0),
      math_ops.equal(suffix[1:], suffix[:-1] + 1))

  # The very first index starts a new row by definition, so its suffix must
  # also be zero (i.e., it must not skip any cells).
  is_ragged_right = math_ops.logical_and(
      math_ops.reduce_all(math_ops.equal(suffix[:1], 0)),
      math_ops.reduce_all(suffix_ok))

  message = [
      "SparseTensor is not right-ragged", "SparseTensor.indices =", indices
  ]
  return [control_flow_ops.Assert(is_ragged_right, message)]
2676
2677
@ops.RegisterGradient("RaggedTensorToSparse")
def _ragged_tensor_to_sparse_gradient(op, unused_sparse_indices_grad,
                                      sparse_values_grad,
                                      unused_sparse_shape_grad):
  """Gradient for RaggedTensorToSparse."""
  nested_row_splits = op.inputs[:-1]
  flat_values = op.inputs[-1]

  # The RaggedTensor's nested_row_splits inputs receive no gradient.
  splits_gradient = [None] * len(nested_row_splits)

  # The flat_values gradient is the SparseTensor values gradient, reshaped
  # back to the shape of the original flat_values.
  values_gradient = array_ops.reshape(sparse_values_grad,
                                      array_ops.shape(flat_values))

  return splits_gradient + [values_gradient]
2696
2697
def _assert_monotonic_increasing(tensor, message=None):
  """Asserts that consecutive differences of `tensor` are non-negative."""
  diffs = tensor[1:] - tensor[:-1]
  return check_ops.assert_non_negative(diffs, message=message)
2701
2702
def _assert_zero(tensor, message=None):
  """Asserts that every element of `tensor` is zero."""
  zero = constant_op.constant(0, dtype=tensor.dtype)
  return check_ops.assert_equal(tensor, zero, message=message)
2706
2707
def _nrows(tensor, out_type=dtypes.int32):
  """Returns the number of rows of `tensor` (a Tensor or RaggedTensor)."""
  if isinstance(tensor, RaggedTensor):
    return tensor.nrows(out_type=out_type)
  return array_ops.shape(tensor, out_type=out_type)[0]
2713
2714
def merge_dims(value, outer_axis, inner_axis):
  """Merges value[outer_axis...inner_axis] into a single dimension.

  See `RaggedTensor.merge_dims()` for more details.  This helper differs from
  `RaggedTensor.merge_dims()` in that `value` may be a dense or ragged tensor.

  Args:
    value: A `RaggedTensor` or `Tensor`
    outer_axis: `int`
    inner_axis: `int`

  Returns:
    A flattened `RaggedTensor` or `Tensor`.
  """
  # Merging a dimension with itself is a no-op.
  if outer_axis == inner_axis:
    return value

  # Flatten outer dimensions of a RaggedTensor by just taking its values.
  while outer_axis == 0 and isinstance(value, RaggedTensor):
    value = value.values
    inner_axis -= 1
    if inner_axis == 0:
      return value

  # Flatten non-Ragged tensors using tf.reshape().
  if not isinstance(value, RaggedTensor):
    if value.shape.is_fully_defined():
      # Static shape: compute the merged shape as a Python list.
      old_shape = value.shape.as_list()
      new_shape = old_shape[:outer_axis] + [-1] + old_shape[inner_axis + 1:]
    else:
      # Dynamic shape: compute the merged shape as a shape tensor.
      old_shape = array_ops.shape(value)
      new_shape = array_ops.concat(
          [old_shape[:outer_axis], [-1], old_shape[inner_axis + 1:]], axis=0)
    return array_ops.reshape(value, new_shape)

  # Handle outer_axis>1 via recursion.
  if outer_axis > 1:
    return value.with_values(
        merge_dims(value.values, outer_axis - 1, inner_axis - 1))

  # At this point, we know outer_axis == 1, and value is a RaggedTensor.
  # So we need to flatten the values and build a corresponding splits tensor.
  new_values = value.values
  new_splits = value.row_splits
  for axis in range(outer_axis, inner_axis):
    if isinstance(new_values, RaggedTensor):
      # Flatten a single ragged dimension.
      new_splits = array_ops.gather(new_values.row_splits, new_splits)
      new_values = new_values.values
    else:
      # Flatten all remaining dense dimensions.
      shape_split = inner_axis - axis + 1
      if new_values.shape.is_fully_defined():
        old_shape = new_values.shape.as_list()
        new_shape = [-1] + old_shape[shape_split:]
        flat_size = _prod(old_shape[1:shape_split])
      else:
        old_shape = array_ops.shape(new_values)
        new_shape = array_ops.concat([[-1], old_shape[shape_split:]], axis=0)
        flat_size = math_ops.cast(
            math_ops.reduce_prod(old_shape[1:shape_split]), new_splits.dtype)
      new_values = array_ops.reshape(new_values, new_shape)
      # Scale the splits so they index into the reshaped (merged) values.
      new_splits = new_splits * flat_size
      # All dims up to inner_axis were merged by the reshape above, so stop.
      break
  return RaggedTensor.from_row_splits(new_values, new_splits)
2780
2781
2782def _prod(lst):
2783  """Returns the product of the numbers in a list."""
2784  return functools.reduce(operator.mul, lst, 1)
2785
2786
2787def _get_row_partition_type_tensor_pairs_tail(partition):
2788  """Gets a row partition type tensor pair for the tail.
2789
2790  If value_rowid is defined, then it is used. Otherwise, row_splits
2791  are used.
2792
2793  Args:
2794    partition: a RowPartition.
2795
2796  Returns:
2797    A list of (row_partition_type, row_partition_tensor) pairs.
2798  """
2799  if partition.has_precomputed_value_rowids():
2800    return ("VALUE_ROWIDS", partition.value_rowids())
2801  else:
2802    return ("ROW_SPLITS", partition.row_splits())
2803
2804
2805def _get_row_partition_type_tensor_pairs(rt_input):
2806  """Gets a list of the row partitions for rt_input.
2807
2808  If value_rowids are defined, then they are used. Otherwise, row_splits
2809  are used. If the outermost level has value_rowids defind, then nrows is
2810  also added.
2811
2812  Args:
2813    rt_input: a ragged tensor.
2814
2815  Returns:
2816    A list of (row_partition_type, row_partition_tensor) pairs.
2817  """
2818  partitions = rt_input._nested_row_partitions  # pylint: disable=protected-access
2819  tail = [_get_row_partition_type_tensor_pairs_tail(x) for x in partitions[1:]]
2820
2821  if partitions[0]._value_rowids is not None:  # pylint: disable=protected-access
2822    return [("FIRST_DIM_SIZE", partitions[0].nrows()),
2823            ("VALUE_ROWIDS", partitions[0].value_rowids())] + tail
2824  else:
2825    return [("ROW_SPLITS", partitions[0].row_splits())] + tail
2826
2827
def _shape_as_tensor(shape, dtype):
  """Takes shape and coerces it to a shape as a tensor.

  If the object is already a tensor, simply passes it on (result is guaranteed
  to be int64 or int32, but not necessarily dtype).
  If not, creates a tensor of type dtype.

  The result is a scalar equal to -1 if the shape is unknown_rank.
  Otherwise, it is a vector, where unknown dimensions are represented with a
  value of -1.

  In C++, see TensorShapeFromTensor for parsing shapes in kernels, and
  InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape, for
  use in the shape inference function.

  Args:
    shape: input to coerce from TensorShape, Tensor, None, List[Optional[Int]],
      Tuple[Optional[Int]].
    dtype: tf.int64 or tf.int32

  Returns:
    a scalar or vector tensor of dtype tf.int32 or tf.int64.
  """
  if dtype not in (dtypes.int64, dtypes.int32):
    raise ValueError("Expected int64 or int32 for dtype: got {}".format(dtype))

  if isinstance(shape, ops.Tensor):
    # Pass index-typed tensors through unchanged; cast anything else.
    if shape.dtype in (dtypes.int64, dtypes.int32):
      return shape
    return math_ops.cast(shape, dtype)

  shape = tensor_shape.as_shape(shape)
  if not shape:
    # Imply rank is unknown using a -1 scalar.
    return constant_op.constant(-1, dtype=dtype)
  # Known rank: encode unknown dimensions as -1.
  dims = [-1 if dim is None else dim for dim in shape.as_list()]
  return constant_op.constant(dims, dtype=dtype)
2865
2866
def _nvals_uniform_row_length(values, uniform_row_length):
  """Get the number of values for uniform row length constructor."""
  out_dtype = uniform_row_length.dtype
  const_nvals = tensor_shape.dimension_at_index(values.shape, 0).value
  if const_nvals is not None:
    # Statically known: emit a constant with the splits' dtype.
    return constant_op.constant(const_nvals, out_dtype)
  if isinstance(values, RaggedTensor):
    return values.nrows(out_type=out_dtype)
  return array_ops.shape(values, out_type=out_dtype)[0]
2877
2878
def _get_optional_partition_dtype(values):
  """Returns the partition dtype, or None if none exists."""
  if not isinstance(values, RaggedTensor):
    return None
  return values._row_partition.dtype  # pylint: disable=protected-access
2885
2886
# Tuple of types accepted as a RaggedTensor's `values`; extended at runtime
# via _add_supported_value_type() below.
_SUPPORTED_RAGGED_VALUE_TYPES = (ops.Tensor, RaggedTensor)
2888
2889
2890# TODO(edloper): Consider whether we should change the registry to be on
2891# TypeSpecs rather than ValueTypes.
def _add_supported_value_type(cls):
  """Register the `cls` as supported value type of RaggedTensor.

  The cls must be a subclass of CompositeTensor, and must support:
   - Properties:
     - x.shape
     - x.dtype
   - Methods:
     - x.__getitem__(idx) (method: returns a supported value type)
   - Ops:
     - tf.shape(x) -- tf.shape(x)[0] must be a tf.Tensor.
     - tf.tile(x)
     - assert_rank_at_least(x)
     - tf.ones_like(x)
     - tf.gather(params=x, indices=Tensor)
     - tf.add(x, y)
     - tf.boolean_mask(x, ...)
     - @TODO(edloper): Complete this list

   Note: the following RaggedTensor, RaggedTensorSpec methods & ops are not
   currently supported unless `rt.values` is a RaggedTensor or a tf.Tensor:
     - rt.to_tensor()
     - rt.to_sparse_tensor()
     - rt._to_variant()
     - rt._from_variant()
     - tf.ragged.cross([rt])
     - tf.gather(params=x, indices=rt)  # rt used for indices
     - RaggedTensorSpec methods:
       - _batch
       - _unbatch
       - _to_tensor_list
       - _to_batched_tensor_list
       - _from_compatible_tensor_list

  Args:
    cls: The type to be added to supported value types.

  Raises:
    ValueError: If `cls` is not a CompositeTensor subclass, or lacks a
      required property.
  """
  if not issubclass(cls, composite_tensor.CompositeTensor):
    raise ValueError("cls(%s) must be a subclass of CompositeTensor" % cls)
  for attr in ("shape", "dtype"):
    if not hasattr(cls, attr):
      raise ValueError("cls must support the `%s` property" % attr)
  global _SUPPORTED_RAGGED_VALUE_TYPES
  _SUPPORTED_RAGGED_VALUE_TYPES += (cls,)
2937
2938
def _is_supported_ragged_values_type(value):
  """Returns true if `value` may be used as a RaggedTensor's `values`."""
  return isinstance(value, _SUPPORTED_RAGGED_VALUE_TYPES)
2941
2942
def _assert_is_supported_ragged_values_type(value):
  """Raises TypeError unless `value` is a supported RaggedTensor value type."""
  if _is_supported_ragged_values_type(value):
    return
  ok_types = ", ".join(
      cls.__name__ for cls in _SUPPORTED_RAGGED_VALUE_TYPES)
  raise TypeError("type(values) must be one of: %r, got %r" %
                  (ok_types, value))
2949