1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes for storing ragged tensors and their values."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.client import session
22from tensorflow.python.framework import composite_tensor
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.framework import tensor_util
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import gen_ragged_conversion_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops.ragged import ragged_tensor_value
34from tensorflow.python.ops.ragged import ragged_util
35from tensorflow.python.ops.ragged import segment_id_ops
36from tensorflow.python.util.tf_export import tf_export
37
# pylint: disable=protected-access
# Module-level alias for the private helper `ops._eval_using_default_session`.
# The attribute is read once, when this module is imported.
_eval_using_default_session = ops._eval_using_default_session

# pylint: enable=protected-access
42
43#===============================================================================
44# RaggedTensor
45#===============================================================================
46
47
48@tf_export("RaggedTensor")
49class RaggedTensor(composite_tensor.CompositeTensor):
50  """Represents a ragged tensor.
51
52  A `RaggedTensor` is a tensor with one or more *ragged dimensions*, which are
53  dimensions whose slices may have different lengths.  For example, the inner
54  (column) dimension of `rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged,
55  since the column slices (`rt[0, :]`, ..., `rt[4, :]`) have different lengths.
56  Dimensions whose slices all have the same length are called *uniform
57  dimensions*.  The outermost dimension of a `RaggedTensor` is always uniform,
58  since it consists of a single slice (and so there is no possibility for
59  differing slice lengths).
60
61  The total number of dimensions in a `RaggedTensor` is called its *rank*,
62  and the number of ragged dimensions in a `RaggedTensor` is called its
63  *ragged-rank*.  A `RaggedTensor`'s ragged-rank is fixed at graph creation
64  time: it can't depend on the runtime values of `Tensor`s, and can't vary
65  dynamically for different session runs.
66
67  ### Potentially Ragged Tensors
68
69  Many ops support both `Tensor`s and `RaggedTensor`s.  The term "potentially
70  ragged tensor" may be used to refer to a tensor that might be either a
71  `Tensor` or a `RaggedTensor`.  The ragged-rank of a `Tensor` is zero.
72
73  ### Documenting RaggedTensor Shapes
74
75  When documenting the shape of a RaggedTensor, ragged dimensions can be
76  indicated by enclosing them in parentheses.  For example, the shape of
77  a 3-D `RaggedTensor` that stores the fixed-size word embedding for each
78  word in a sentence, for each sentence in a batch, could be written as
79  `[num_sentences, (num_words), embedding_size]`.  The parentheses around
80  `(num_words)` indicate that dimension is ragged, and that the length
81  of each element list in that dimension may vary for each item.
82
83  ### Component Tensors
84
85  Internally, a `RaggedTensor` consists of a concatenated list of values that
86  are partitioned into variable-length rows.  In particular, each `RaggedTensor`
87  consists of:
88
89    * A `values` tensor, which concatenates the variable-length rows into a
90      flattened list.  For example, the `values` tensor for
91      `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is `[3, 1, 4, 1, 5, 9, 2, 6]`.
92
93    * A `row_splits` vector, which indicates how those flattened values are
94      divided into rows.  In particular, the values for row `rt[i]` are stored
95      in the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.
96
97  Example:
98
99  ```python
100  >>> print(tf.RaggedTensor.from_row_splits(
101  ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
102  ...     row_splits=[0, 4, 4, 7, 8, 8]))
103  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
104  ```
105
106  ### Alternative Row-Partitioning Schemes
107
108  In addition to `row_splits`, ragged tensors provide support for four other
109  row-partitioning schemes:
110
111    * `row_lengths`: a vector with shape `[nrows]`, which specifies the length
112      of each row.
113
114    * `value_rowids` and `nrows`: `value_rowids` is a vector with shape
115      `[nvals]`, corresponding one-to-one with `values`, which specifies
116      each value's row index.  In particular, the row `rt[row]` consists of the
117      values `rt.values[j]` where `value_rowids[j]==row`.  `nrows` is an
118      int64 scalar that specifies the number of rows in the `RaggedTensor`.
119      (`nrows` is used to indicate trailing empty rows.)
120
121    * `row_starts`: a vector with shape `[nrows]`, which specifies the start
122      offset of each row.  Equivalent to `row_splits[:-1]`.
123
124    * `row_limits`: a vector with shape `[nrows]`, which specifies the stop
125      offset of each row.  Equivalent to `row_splits[1:]`.
126
127  Example: The following ragged tensors are equivalent, and all represent the
128  nested list `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]`.
129
130  ```python
131  >>> values = [3, 1, 4, 1, 5, 9, 2, 6]
132  >>> rt1 = RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 7, 8, 8])
133  >>> rt2 = RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 3, 1, 0])
134  >>> rt3 = RaggedTensor.from_value_rowids(
135  ...     values, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
136  >>> rt4 = RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8])
137  >>> rt5 = RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8])
138  ```
139
140  ### Multiple Ragged Dimensions
141
142  `RaggedTensor`s with multiple ragged dimensions can be defined by using
143  a nested `RaggedTensor` for the `values` tensor.  Each nested `RaggedTensor`
144  adds a single ragged dimension.
145
146  ```python
147  >>> inner_rt = RaggedTensor.from_row_splits(  # =rt1 from above
148  ...     values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
149  >>> outer_rt = RaggedTensor.from_row_splits(
150  ...     values=inner_rt, row_splits=[0, 3, 3, 5])
  >>> print(outer_rt.to_list())
  [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]
  >>> print(outer_rt.ragged_rank)
  2
155  ```
156
157  The factory function `RaggedTensor.from_nested_row_splits` may be used to
158  construct a `RaggedTensor` with multiple ragged dimensions directly, by
159  providing a list of `row_splits` tensors:
160
161  ```python
162  >>> RaggedTensor.from_nested_row_splits(
163  ...     flat_values=[3, 1, 4, 1, 5, 9, 2, 6],
164  ...     nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8])).to_list()
165  [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]
166  ```
167
168  ### Uniform Inner Dimensions
169
170  `RaggedTensor`s with uniform inner dimensions can be defined
171  by using a multidimensional `Tensor` for `values`.
172
173  ```python
  >>> rt = RaggedTensor.from_row_splits(values=tf.ones([5, 3]),
  ...                                   row_splits=[0, 2, 5])
  >>> print(rt.to_list())
  [[[1, 1, 1], [1, 1, 1]],
   [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]
  >>> print(rt.shape)
  (2, ?, 3)
181  ```
182
183  ### RaggedTensor Shape Restrictions
184
185  The shape of a RaggedTensor is currently restricted to have the following
186  form:
187
188    * A single uniform dimension
189    * Followed by one or more ragged dimensions
190    * Followed by zero or more uniform dimensions.
191
192  This restriction follows from the fact that each nested `RaggedTensor`
193  replaces the uniform outermost dimension of its `values` with a uniform
194  dimension followed by a ragged dimension.
195  """
196
197  #=============================================================================
198  # Constructor (private)
199  #=============================================================================
200  def __init__(self,
201               values,
202               row_splits,
203               cached_row_lengths=None,
204               cached_value_rowids=None,
205               cached_nrows=None,
206               internal=False):
207    """Creates a `RaggedTensor` with a specified partitioning for `values`.
208
209    This constructor is private -- please use one of the following ops to
210    build `RaggedTensor`s:
211
212      * `tf.RaggedTensor.from_row_lengths`
213      * `tf.RaggedTensor.from_value_rowids`
214      * `tf.RaggedTensor.from_row_splits`
215      * `tf.RaggedTensor.from_row_starts`
216      * `tf.RaggedTensor.from_row_limits`
217      * `tf.RaggedTensor.from_nested_row_splits`
218      * `tf.RaggedTensor.from_nested_row_lengths`
219      * `tf.RaggedTensor.from_nested_value_rowids`
220
221    Args:
222      values: A potentially ragged tensor of any dtype and shape `[nvals, ...]`.
223      row_splits: A 1-D int64 tensor with shape `[nrows+1]`.
224      cached_row_lengths: A 1-D int64 tensor with shape `[nrows]`
225      cached_value_rowids: A 1-D int64 tensor with shape `[nvals]`.
226      cached_nrows: A 1-D int64 scalar tensor.
227      internal: True if the constructor is being called by one of the factory
228        methods.  If false, an exception will be raised.
229
230    Raises:
231      TypeError: If a row partitioning tensor has an inappropriate dtype.
232      TypeError: If exactly one row partitioning argument was not specified.
233      ValueError: If a row partitioning tensor has an inappropriate shape.
234      ValueError: If multiple partitioning arguments are specified.
235      ValueError: If nrows is specified but value_rowids is not None.
236    """
237    if not internal:
238      raise ValueError("RaggedTensor constructor is private; please use one "
239                       "of the factory methods instead (e.g., "
240                       "RaggedTensor.from_row_lengths())")
241
242    # Validate the arguments.
243    if not isinstance(values, (RaggedTensor, ops.Tensor)):
244      raise TypeError("values must be a Tensor or RaggedTensor.")
245    if not isinstance(row_splits, ops.Tensor):
246      raise TypeError("Row-partitioning argument must be a Tensor.")
247    values.shape.with_rank_at_least(1)
248    row_splits.shape.assert_has_rank(1)
249    row_splits.set_shape([None])
250
251    self._values = values
252    self._row_splits = row_splits
253
254    # Store any cached tensors.  These are used to avoid unnecessary
255    # round-trip conversions when a RaggedTensor is constructed from
256    # lengths or rowids, and we later want those lengths/rowids back.
257    for tensor in [cached_row_lengths, cached_value_rowids, cached_nrows]:
258      if tensor is not None and not isinstance(tensor, ops.Tensor):
259        raise TypeError("Cached value must be a Tensor or None.")
260    self._cached_row_lengths = cached_row_lengths
261    self._cached_value_rowids = cached_value_rowids
262    self._cached_nrows = cached_nrows
263
264  #=============================================================================
265  # Factory Methods
266  #=============================================================================
267
268  @classmethod
269  def from_value_rowids(cls, values, value_rowids, nrows=None, name=None):
270    """Creates a `RaggedTensor` with rows partitioned by `value_rowids`.
271
272    The returned `RaggedTensor` corresponds with the python list defined by:
273
274    ```python
275    result = [[values[i] for i in range(len(values)) if value_rowids[i] == row]
276              for row in range(nrows)]
277    ```
278
279    Warning: currently, this needs to cast value_rowids to int64 before
280    converting, since `tf.bincount` only supports `int32`.
281
282    Args:
283      values: A potentially ragged tensor with shape `[nvals, ...]`.
284      value_rowids: A 1-D int64 tensor with shape `[nvals]`, which corresponds
285        one-to-one with `values`, and specifies each value's row index.  Must be
286        nonnegative, and must be sorted in ascending order.
287      nrows: An int64 scalar specifying the number of rows.  This should be
288        specified if the `RaggedTensor` may containing empty training rows. Must
289        be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty).
290        Defaults to `value_rowids[-1]` (or zero if `value_rowids` is empty).
291      name: A name prefix for the RaggedTensor (optional).
292
293    Returns:
294      A `RaggedTensor`.  `result.rank = values.rank + 1`.
295      `result.ragged_rank = values.ragged_rank + 1`.
296
297    Raises:
298      ValueError: If `nrows` is incompatible with `value_rowids`.
299
300    #### Example:
301      ```python
302      >>> print(tf.RaggedTensor.from_value_rowids(
303      ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
304      ...     value_rowids=[0, 0, 0, 0, 2, 2, 2, 3],
305      ...     nrows=5))
306      <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
307      ```
308    """
309    with ops.name_scope(name, "RaggedFromValueRowIds",
310                        [values, value_rowids, nrows]):
311      values = convert_to_tensor_or_ragged_tensor(values, name="values")
312      value_rowids = ops.convert_to_tensor(
313          value_rowids, dtypes.int64, name="value_rowids")
314      if nrows is None:
315        const_rowids = tensor_util.constant_value(value_rowids)
316        if const_rowids is None:
317          nrows = array_ops.concat([value_rowids[-1:], [-1]], axis=0)[0] + 1
318          const_nrows = None
319        else:
320          const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0
321          nrows = ops.convert_to_tensor(const_nrows, dtypes.int64, name="nrows")
322      else:
323        nrows = ops.convert_to_tensor(nrows, dtypes.int64, "nrows")
324        const_nrows = tensor_util.constant_value(nrows)
325        if const_nrows is not None:
326          if const_nrows < 0:
327            raise ValueError("Expected nrows >= 0; got %d" % const_nrows)
328          const_rowids = tensor_util.constant_value(value_rowids)
329          if const_rowids is not None and const_rowids.size > 0:
330            if not const_nrows >= const_rowids[-1] + 1:
331              raise ValueError(
332                  "Expected nrows >= value_rowids[-1] + 1; got nrows=%d, "
333                  "value_rowids[-1]=%d" % (const_nrows, const_rowids[-1]))
334
335      value_rowids.shape.assert_has_rank(1)
336      nrows.shape.assert_has_rank(0)
337      values.shape[:1].assert_is_compatible_with(value_rowids.shape)
338
339      # Convert value_rowids & nrows to row_splits.
340      # Note: we don't use segment_ids_to_row_splits() here because we want
341      # to save the intermediate value `row_lengths`, so we can cache it.
342      # TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the
343      # cast (Remove the warning in the docstring when we do.)
344      value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32)
345      nrows_int32 = math_ops.cast(nrows, dtypes.int32)
346      row_lengths = math_ops.bincount(
347          value_rowids_int32,
348          minlength=nrows_int32,
349          maxlength=nrows_int32,
350          dtype=dtypes.int64)
351      row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)
352      if const_nrows is not None:
353        row_lengths.set_shape([const_nrows])
354        row_splits.set_shape([const_nrows + 1])
355
356      return cls(
357          values,
358          row_splits,
359          cached_row_lengths=row_lengths,
360          cached_value_rowids=value_rowids,
361          cached_nrows=nrows,
362          internal=True)
363
364  @classmethod
365  def from_row_splits(cls, values, row_splits, name=None):
366    """Creates a `RaggedTensor` with rows partitioned by `row_splits`.
367
368    The returned `RaggedTensor` corresponds with the python list defined by:
369
370    ```python
371    result = [values[row_splits[i]:row_splits[i + 1]]
372              for i in range(len(row_splits) - 1)]
373    ```
374
375    Args:
376      values: A potentially ragged tensor with shape `[nvals, ...]`.
377      row_splits: A 1-D int64 tensor with shape `[nrows+1]`.  Must not be empty,
378        and must be sorted in ascending order.  `row_splits[0]` must be zero and
379        `row_splits[-1]` must be `nvals`.
380      name: A name prefix for the RaggedTensor (optional).
381
382    Returns:
383      A `RaggedTensor`.  `result.rank = values.rank + 1`.
384      `result.ragged_rank = values.ragged_rank + 1`.
385
386    Raises:
387      ValueError: If `row_splits` is an empty list.
388
389    #### Example:
390      ```python
391      >>> print(tf.RaggedTensor.from_row_splits(
392      ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
393      ...     row_splits=[0, 4, 4, 7, 8, 8]))
394      <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
395      ```
396    """
397    if isinstance(row_splits, (list, tuple)) and not row_splits:
398      raise ValueError("row_splits tensor may not be empty.")
399    with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]):
400      values = convert_to_tensor_or_ragged_tensor(values, name="values")
401      row_splits = ops.convert_to_tensor(row_splits, dtypes.int64, "row_splits")
402      row_splits.shape.assert_has_rank(1)
403      return cls(values=values, row_splits=row_splits, internal=True)
404
405  @classmethod
406  def from_row_lengths(cls, values, row_lengths, name=None):
407    """Creates a `RaggedTensor` with rows partitioned by `row_lengths`.
408
409    The returned `RaggedTensor` corresponds with the python list defined by:
410
411    ```python
412    result = [[values.pop(0) for i in range(length)]
413              for length in row_lengths]
414    ```
415
416    Args:
417      values: A potentially ragged tensor with shape `[nvals, ...]`.
418      row_lengths: A 1-D int64 tensor with shape `[nrows]`.  Must be
419        nonnegative.  `sum(row_lengths)` must be `nvals`.
420      name: A name prefix for the RaggedTensor (optional).
421
422    Returns:
423      A `RaggedTensor`.  `result.rank = values.rank + 1`.
424      `result.ragged_rank = values.ragged_rank + 1`.
425
426    #### Example:
427      ```python
428      >>> print(tf.RaggedTensor.from_row_lengths(
429      ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
430      ...     row_lengths=[4, 0, 3, 1, 0]))
431      <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []])>
432      ```
433    """
434    with ops.name_scope(name, "RaggedFromRowLengths", [values, row_lengths]):
435      values = convert_to_tensor_or_ragged_tensor(values, name="values")
436      row_lengths = ops.convert_to_tensor(row_lengths, dtypes.int64,
437                                          "row_lengths")
438      row_lengths.shape.assert_has_rank(1)
439      row_limits = math_ops.cumsum(row_lengths)
440      row_splits = array_ops.concat([[0], row_limits], axis=0)
441      return cls(
442          values=values,
443          row_splits=row_splits,
444          cached_row_lengths=row_lengths,
445          internal=True)
446
447  @classmethod
448  def from_row_starts(cls, values, row_starts, name=None):
449    """Creates a `RaggedTensor` with rows partitioned by `row_starts`.
450
451    Equivalent to: `from_row_splits(values, concat([row_starts, nvals]))`.
452
453    Args:
454      values: A potentially ragged tensor with shape `[nvals, ...]`.
455      row_starts: A 1-D int64 tensor with shape `[nrows]`.  Must be nonnegative
456        and sorted in ascending order.  If `nrows>0`, then `row_starts[0]` must
457        be zero.
458      name: A name prefix for the RaggedTensor (optional).
459
460    Returns:
461      A `RaggedTensor`.  `result.rank = values.rank + 1`.
462      `result.ragged_rank = values.ragged_rank + 1`.
463
464    #### Example:
465      ```python
466      >>> print(tf.RaggedTensor.from_row_starts(
467      ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
468      ...     row_starts=[0, 4, 4, 7, 8]))
469      <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
470      ```
471    """
472    with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]):
473      values = convert_to_tensor_or_ragged_tensor(values, name="values")
474      row_starts = ops.convert_to_tensor(row_starts, dtypes.int64, "row_starts")
475      row_starts.shape.assert_has_rank(1)
476      nvals = array_ops.shape(values, out_type=dtypes.int64)[:1]
477      row_splits = array_ops.concat([row_starts, nvals], axis=0)
478      return cls(values=values, row_splits=row_splits, internal=True)
479
480  @classmethod
481  def from_row_limits(cls, values, row_limits, name=None):
482    """Creates a `RaggedTensor` with rows partitioned by `row_limits`.
483
484    Equivalent to: `from_row_splits(values, concat([0, row_limits]))`.
485
486    Args:
487      values: A potentially ragged tensor with shape `[nvals, ...]`.
488      row_limits: A 1-D int64 tensor with shape `[nrows]`.  Must be sorted in
489        ascending order.  If `nrows>0`, then `row_limits[-1]` must be `nvals`.
490      name: A name prefix for the RaggedTensor (optional).
491
492    Returns:
493      A `RaggedTensor`.  `result.rank = values.rank + 1`.
494      `result.ragged_rank = values.ragged_rank + 1`.
495
496    #### Example:
497      ```python
498      >>> print(tf.RaggedTensor.from_row_limits(
499      ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
500      ...     row_limits=[4, 4, 7, 8, 8]))
501      <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
502      ```
503    """
504    with ops.name_scope(name, "RaggedFromRowLimits", [values, row_limits]):
505      values = convert_to_tensor_or_ragged_tensor(values, name="values")
506      row_limits = ops.convert_to_tensor(row_limits, dtypes.int64, "row_limits")
507      row_limits.shape.assert_has_rank(1)
508      zero = array_ops.zeros([1], dtypes.int64)
509      row_splits = array_ops.concat([zero, row_limits], axis=0)
510      return cls(values=values, row_splits=row_splits, internal=True)
511
512  @classmethod
513  def from_nested_value_rowids(cls,
514                               flat_values,
515                               nested_value_rowids,
516                               nested_nrows=None,
517                               name=None):
518    """Creates a `RaggedTensor` from a nested list of `value_rowids` tensors.
519
520    Equivalent to:
521
522    ```python
523    result = flat_values
524    for (rowids, nrows) in reversed(zip(nested_value_rowids, nested_nrows)):
525      result = from_value_rowids(result, rowids, nrows)
526    ```
527
528    Args:
529      flat_values: A potentially ragged tensor.
530      nested_value_rowids: A list of 1-D int64 tensors.  The `i`th tensor is
531        used as the `value_rowids` for the `i`th ragged dimension.
532      nested_nrows: A list of int64 scalars.  The `i`th scalar is used as the
533        `nrows` for the `i`th ragged dimension.
534      name: A name prefix for the RaggedTensor (optional).
535
536    Returns:
537      A `RaggedTensor` (or `flat_values` if `nested_value_rowids` is empty).
538
539    Raises:
540      ValueError: If `len(nested_values_rowids) != len(nested_nrows)`.
541    """
542    if isinstance(nested_value_rowids, ops.Tensor):
543      raise TypeError("nested_value_rowids must be a list of Tensors")
544    if nested_nrows is None:
545      nested_nrows = [None] * len(nested_value_rowids)
546    else:
547      if isinstance(nested_nrows, ops.Tensor):
548        raise TypeError("nested_nrows must be a list of Tensors")
549      if len(nested_nrows) != len(nested_value_rowids):
550        raise ValueError("nested_nrows must have the same length as "
551                         "nested_value_rowids")
552
553    with ops.name_scope(
554        name, "RaggedFromNestedValueRowIds",
555        [flat_values] + list(nested_value_rowids) + list(nested_nrows)):
556      result = flat_values
557      for value_rowids, nrows in reversed(
558          list(zip(nested_value_rowids, nested_nrows))):
559        result = cls.from_value_rowids(result, value_rowids, nrows)
560      return result
561
562  @classmethod
563  def from_nested_row_splits(cls, flat_values, nested_row_splits, name=None):
564    """Creates a `RaggedTensor` from a nested list of `row_splits` tensors.
565
566    Equivalent to:
567
568    ```python
569    result = flat_values
570    for row_splits in reversed(nested_row_splits):
571      result = from_row_splits(result, row_splits)
572    ```
573
574    Args:
575      flat_values: A potentially ragged tensor.
576      nested_row_splits: A list of 1-D int64 tensors.  The `i`th tensor is used
577        as the `row_splits` for the `i`th ragged dimension.
578      name: A name prefix for the RaggedTensor (optional).
579
580    Returns:
581      A `RaggedTensor` (or `flat_values` if `nested_row_splits` is empty).
582    """
583    if isinstance(nested_row_splits, ops.Tensor):
584      raise TypeError("nested_row_splits must be a list of Tensors")
585    with ops.name_scope(name, "RaggedFromNestedRowSplits",
586                        [flat_values] + list(nested_row_splits)):
587      result = flat_values
588      for splits in reversed(nested_row_splits):
589        result = cls.from_row_splits(result, splits)
590      return result
591
592  @classmethod
593  def from_nested_row_lengths(cls, flat_values, nested_row_lengths, name=None):
594    """Creates a `RaggedTensor` from a nested list of `row_lengths` tensors.
595
596    Equivalent to:
597
598    ```python
599    result = flat_values
600    for row_lengths in reversed(nested_row_lengths):
601      result = from_row_lengths(result, row_lengths)
602    ```
603
604    Args:
605      flat_values: A potentially ragged tensor.
606      nested_row_lengths: A list of 1-D int64 tensors.  The `i`th tensor is used
607        as the `row_lengths` for the `i`th ragged dimension.
608      name: A name prefix for the RaggedTensor (optional).
609
610    Returns:
611      A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty).
612    """
613    if isinstance(nested_row_lengths, ops.Tensor):
614      raise TypeError("nested_row_lengths must be a list of Tensors")
615    with ops.name_scope(name, "RaggedFromNestedRowlengths",
616                        [flat_values] + list(nested_row_lengths)):
617      result = flat_values
618      for lengths in reversed(nested_row_lengths):
619        result = cls.from_row_lengths(result, lengths)
620      return result
621
622  #=============================================================================
623  # Accessors
624  #=============================================================================
625
626  @property
627  def dtype(self):
628    """The `DType` of values in this tensor."""
629    return self._values.dtype
630
631  @property
632  def shape(self):
633    """The statically known shape of this ragged tensor.
634
635    Returns:
636      A `TensorShape` containing the statically known shape of this ragged
637      tensor.  Ragged dimensions have a size of `None`.
638
639    Examples:
640
641      ```python
642      >>> ragged.constant([[0], [1, 2]]).shape
643      TensorShape([Dimension(2), Dimension(None)])
644
645      >>> ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).shape
646      TensorShape([Dimension(2), Dimension(None), Dimension(2)
647      ```
648    """
649    nrows = tensor_shape.dimension_at_index(self._row_splits.shape, 0) - 1
650
651    values_shape = self._values.shape
652    value_shape = values_shape[1:]
653    return tensor_shape.TensorShape([nrows, None]).concatenate(value_shape)
654
655  @property
656  def ragged_rank(self):
657    """The number of ragged dimensions in this ragged tensor.
658
659    Returns:
660      A Python `int` indicating the number of ragged dimensions in this ragged
661      tensor.  The outermost dimension is not considered ragged.
662    """
663    values_is_ragged = isinstance(self._values, RaggedTensor)
664    return self._values.ragged_rank + 1 if values_is_ragged else 1
665
666  @property
667  def values(self):
668    """The concatenated rows for this ragged tensor.
669
670    `rt.values` is a potentially ragged tensor formed by flattening the two
671    outermost dimensions of `rt` into a single dimension.
672
673    `rt.values.shape = [nvals] + rt.shape[2:]` (where `nvals` is the
674    number of items in the outer two dimensions of `rt`).
675
676    `rt.ragged_rank = self.ragged_rank - 1`
677
678    Returns:
679      A potentially ragged tensor.
680
681    #### Example:
682      ```python
683      >>> rt = ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
684      >>> print rt.values
685      tf.Tensor([3, 1, 4, 1, 5, 9, 2, 6])
686      ```
687    """
688    return self._values
689
690  @property
691  def row_splits(self):
692    """The row-split indices for this ragged tensor's `values`.
693
694    `rt.row_splits` specifies where the values for each row begin and end in
695    `rt.values`.  In particular, the values for row `rt[i]` are stored in
696    the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.
697
698    Returns:
699      A 1-D `int64` `Tensor` with shape `[self.nrows+1]`.
700      The returned tensor is non-empty, and is sorted in ascending order.
701      `self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to
702      `self.values.shape[0]`.
703
704    #### Example:
705      ```python
706      >>> rt = ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
707      >>> print rt.row_splits  # indices of row splits in rt.values
708      tf.Tensor([0, 4, 4, 7, 8, 8])
709      ```
710    """
711    return self._row_splits
712
713  @property
714  def flat_values(self):
715    """The innermost `values` tensor for this ragged tensor.
716
717    Concretely, if `rt.values` is a `Tensor`, then `rt.flat_values` is
718    `rt.values`; otherwise, `rt.flat_values` is `rt.values.flat_values`.
719
720    Conceptually, `flat_values` is the tensor formed by flattening the
721    outermost dimension and all of the ragged dimensions into a single
722    dimension.
723
724    `rt.flat_values.shape = [nvals] + rt.shape[rt.ragged_rank + 1:]`
725    (where `nvals` is the number of items in the flattened dimensions).
726
727    Returns:
728      A `Tensor`.
729
730    #### Example:
731
732      ```python
733      >>> rt = ragged.constant([[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
734      >>> print rt.flat_values()
735      tf.Tensor([3, 1, 4, 1, 5, 9, 2, 6])
736      ```
737    """
738    rt_values = self.values
739    while isinstance(rt_values, RaggedTensor):
740      rt_values = rt_values.values
741    return rt_values
742
743  @property
744  def nested_row_splits(self):
745    """A tuple containing the row_splits for all ragged dimensions.
746
747    `rt.nested_row_splits` is a tuple containing the `row_splits` tensors for
748    all ragged dimensions in `rt`, ordered from outermost to innermost.  In
749    particular, `rt.nested_row_splits = (rt.row_splits,) + value_splits` where:
750
751        * `value_splits = ()` if `rt.values` is a `Tensor`.
752        * `value_splits = rt.values.nested_row_splits` otherwise.
753
754    Returns:
755      A `tuple` of 1-D `int64` `Tensor`s.
756
757    #### Example:
758
759      ```python
760      >>> rt = ragged.constant([[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]])
761      >>> for i, splits in enumerate(rt.nested_row_splits()):
762      ...   print('Splits for dimension %d: %s' % (i+1, splits))
763      Splits for dimension 1: [0, 1]
764      Splits for dimension 2: [0, 3, 3, 5]
765      Splits for dimension 3: [0, 4, 4, 7, 8, 8]
766      ```
767
768    """
769    rt_nested_splits = [self.row_splits]
770    rt_values = self.values
771    while isinstance(rt_values, RaggedTensor):
772      rt_nested_splits.append(rt_values.row_splits)
773      rt_values = rt_values.values
774    return tuple(rt_nested_splits)
775
776  def value_rowids(self, name=None):
777    """Returns the row indices for the `values` in this ragged tensor.
778
779    `rt.value_rowids()` corresponds one-to-one with the outermost dimension of
780    `rt.values`, and specifies the row containing each value.  In particular,
781    the row `rt[row]` consists of the values `rt.values[j]` where
782    `rt.value_rowids()[j] == row`.
783
784    Args:
785      name: A name prefix for the returned tensor (optional).
786
787    Returns:
788      A 1-D `int64` `Tensor` with shape `self.values.shape[:1]`.
789      The returned tensor is nonnegative, and is sorted in ascending order.
790
791    #### Example:
792      ```python
793      >>> rt = ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
794      >>> rt.values
795      tf.Tensor([3, 1, 4, 1, 5, 9, 2, 6])
796      >>> rt.value_rowids()
797      tf.Tensor([0, 0, 0, 0, 2, 2, 2, 3])  # corresponds 1:1 with rt.values
798      ```
799    """
800    if self._cached_value_rowids is not None:
801      return self._cached_value_rowids
802
803    with ops.name_scope(name, "RaggedValueRowIds", [self]):
804      return segment_id_ops.row_splits_to_segment_ids(self.row_splits)
805
806  def nrows(self, out_type=dtypes.int64, name=None):
807    """Returns the number of rows in this ragged tensor.
808
809    I.e., the size of the outermost dimension of the tensor.
810
811    Args:
812      out_type: `dtype` for the returned tensor.
813      name: A name prefix for the returned tensor (optional).
814
815    Returns:
816      A scalar `Tensor` with dtype `out_type`.
817
818    #### Example:
819      ```python
820      >>> rt = ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
821      >>> rt.nrows()  # rt has 5 rows.
822      5
823      ```
824    """
825    if self._cached_nrows is not None:
826      return self._cached_nrows
827
828    with ops.name_scope(name, "RaggedNRows", [self]):
829      return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1
830
831  def row_starts(self, name=None):
832    """Returns the start indices for rows in this ragged tensor.
833
834    These indices specify where the values for each row begin in
835    `self.values`.  `rt.row_starts()` is equal to `rt.row_splits[:-1]`.
836
837    Args:
838      name: A name prefix for the returned tensor (optional).
839
840    Returns:
841      A 1-D Tensor of int64 with shape `[nrows]`.
842      The returned tensor is nonnegative, and is sorted in ascending order.
843
844    #### Example:
845      ```python
846      >>> rt = ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
847      >>> rt.values
848      tf.Tensor([3, 1, 4, 1, 5, 9, 2, 6])
849      >>> rt.row_starts()  # indices of row starts in rt.values
850      tf.Tensor([0, 4, 4, 7, 8])
851      ```
852    """
853    with ops.name_scope(name, "RaggedRowStarts", [self]):
854      return self.row_splits[:-1]
855
856  def row_limits(self, name=None):
857    """Returns the limit indices for rows in this ragged tensor.
858
859    These indices specify where the values for each row end in
860    `self.values`.  `rt.row_limits(self)` is equal to `rt.row_splits[:-1]`.
861
862    Args:
863      name: A name prefix for the returned tensor (optional).
864
865    Returns:
866      A 1-D Tensor of int64 with shape `[nrows]`.
867      The returned tensor is nonnegative, and is sorted in ascending order.
868
869    #### Example:
870      ```python
871      >>> rt = ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
872      >>> rt.values
873      tf.Tensor([3, 1, 4, 1, 5, 9, 2, 6])
874      >>> rt.row_limits()  # indices of row limits in rt.values
875      tf.Tensor([4, 4, 7, 8, 8])
876      ```
877    """
878    with ops.name_scope(name, "RaggedRowLimits", [self]):
879      return self.row_splits[1:]
880
881  def row_lengths(self, axis=1, name=None):
882    """Returns the lengths of the rows in this ragged tensor.
883
884    `rt.row_lengths()[i]` indicates the number of values in the
885    `i`th row of `rt`.
886
887    Args:
888      axis: An integer constant indicating the axis whose row lengths should be
889        returned.
890      name: A name prefix for the returned tensor (optional).
891
892    Returns:
893      A potentially ragged Tensor of int64 with shape `self.shape[:axis]`.
894
895    Raises:
896      ValueError: If `axis` is out of bounds.
897
898    #### Example:
899      ```python
900      >>> rt = ragged.constant([[[3, 1, 4], [1]], [], [[5, 9], [2]], [[6]], []])
901      >>> rt.row_lengths(rt)  # lengths of rows in rt
902      tf.Tensor([2, 0, 2, 1, 0])
903      >>> rt.row_lengths(axis=2)  # lengths of axis=2 rows.
904      <tf.RaggedTensor [[3, 1], [], [2, 1], [1], []]>
905      ```
906    """
907    if self._cached_row_lengths is not None:
908      return self._cached_row_lengths
909
910    with ops.name_scope(name, "RaggedRowLengths", [self]):
911      axis = ragged_util.get_positive_axis(axis, self.shape.ndims)
912      if axis == 0:
913        return self.nrows()
914      elif axis == 1:
915        splits = self.row_splits
916        return splits[1:] - splits[:-1]
917      elif isinstance(self.values, RaggedTensor):
918        return self.with_values(self.values.row_lengths(axis - 1))
919      else:
920        shape = array_ops.shape(self.values, out_type=dtypes.int64)
921        return self.with_values(
922            array_ops.ones(shape[:axis - 1], dtypes.int64) * shape[axis - 1])
923
924  def nested_row_lengths(self, name=None):
925    """Returns a tuple containing the row_lengths for all ragged dimensions.
926
927    `rtnested_row_lengths()` is a tuple containing the `row_lengths` tensors for
928    all ragged dimensions in `rt`, ordered from outermost to innermost.
929
930    Args:
931      name: A name prefix for the returned tensors (optional).
932
933    Returns:
934      A `tuple` of 1-D `int64` `Tensors`.  The length of the tuple is equal to
935      `self.ragged_rank`.
936    """
937    with ops.name_scope(name, "RaggedNestedRowLengths", [self]):
938      rt_nested_row_lengths = []
939      rt = self
940      while isinstance(rt, RaggedTensor):
941        rt_nested_row_lengths.append(rt.row_lengths())
942        rt = rt.values
943      return tuple(rt_nested_row_lengths)
944
945  def bounding_shape(self, axis=None, name=None):
946    """Returns the tight bounding box shape for this `RaggedTensor`.
947
948    Args:
949      axis: An integer scalar or vector indicating which axes to return the
950        bounding box for.  If not specified, then the full bounding box is
951        returned.
952      name: A name prefix for the returned tensor (optional).
953
954    Returns:
955      An int64 `Tensor`.  If `axis` is not specified, then `output`
956      is a vector with `output.shape=[self.shape.ndims]`.  If `axis` is a
957      scalar, then the `output` is a scalar.  If `axis` is a vector, then
958      `output` is a vector, where `output[i]` is the bounding size for
959      dimension `axis[i]`.
960
961    #### Example:
962      ```python
963      >>> rt = ragged.constant([[1, 2, 3, 4], [5], [], [6, 7, 8, 9], [10]])
964      >>> rt.bounding_shape()
965      [5, 4]
966      ```
967    """
968    with ops.name_scope(name, "RaggedBoundingBox", [self, axis]):
969      nested_splits = self.nested_row_splits
970      rt_flat_values = self.flat_values
971
972      # Optimized special cases for when axis=0 or axis=1:
973      if isinstance(axis, int):
974        if axis == 0:
975          return array_ops.shape(nested_splits[0], out_type=dtypes.int64)[0] - 1
976        elif axis == 1:
977          return math_ops.maximum(math_ops.reduce_max(self.row_lengths()), 0)
978
979      splits_shape = array_ops.shape(self.row_splits, out_type=dtypes.int64)
980      flat_values_shape = array_ops.shape(rt_flat_values, out_type=dtypes.int64)
981
982      ragged_dimensions = array_ops.stack([splits_shape[0] - 1] + [
983          math_ops.maximum(math_ops.reduce_max(splits[1:] - splits[:-1]), 0)
984          for splits in nested_splits
985      ])
986      inner_dimensions = flat_values_shape[1:]
987
988      bbox = array_ops.concat([ragged_dimensions, inner_dimensions], axis=0)
989      return bbox if axis is None else array_ops.gather(bbox, axis)
990
991  #=============================================================================
992  # Transformation
993  #=============================================================================
994
995  def with_values(self, new_values):
996    """Returns a copy of `self` with `values` replaced by `new_value`.
997
998    Preserves cached row-partitioning tensors such as `self.cached_nrows` and
999    `self.cached_value_rowids` if they have values.
1000
1001    Args:
1002      new_values: Potentially ragged tensor to use as the `values` for the
1003        returned `RaggedTensor`.  Must have `rank > 0`, and must have the same
1004        number of rows as `self.values`.
1005
1006    Returns:
1007      A `RaggedTensor`.  `result.rank = 1 + new_values.rank`.
1008      `result.ragged_rank = 1 + new_values.ragged_rank`
1009    """
1010    new_values.shape.with_rank_at_least(1)
1011    self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1])
1012    return RaggedTensor(
1013        new_values,
1014        self._row_splits,
1015        self._cached_row_lengths,
1016        self._cached_value_rowids,
1017        self._cached_nrows,
1018        internal=True)
1019
1020  def with_flat_values(self, new_values):
1021    """Returns a copy of `self` with `flat_values` replaced by `new_value`.
1022
1023    Preserves cached row-partitioning tensors such as `self.cached_nrows` and
1024    `self.cached_value_rowids` if they have values.
1025
1026    Args:
1027      new_values: Potentially ragged tensor that should replace
1028      `self.flat_values`.  Must have `rank > 0`, and must have the same
1029      number of rows as `self.flat_values`.
1030
1031    Returns:
1032      A `RaggedTensor`.
1033      `result.rank = self.ragged_rank + new_values.rank`.
1034      `result.ragged_rank = self.ragged_rank + new_values.ragged_rank`.
1035    """
1036    if isinstance(self._values, ops.Tensor):
1037      return self.with_values(new_values)
1038    else:
1039      return self.with_values(self.values.with_flat_values(new_values))
1040
1041  #=============================================================================
1042  # Tensor Type Conversions
1043  #=============================================================================
1044
  @classmethod
  def from_tensor(cls,
                  tensor,
                  lengths=None,
                  padding=None,
                  ragged_rank=1,
                  name=None):
    """Converts a `tf.Tensor` into a `RaggedTensor`.

    The set of absent/default values may be specified using a vector of lengths
    or a padding value (but not both).  If `lengths` is specified, then the
    output tensor will satisfy `output[row] = tensor[row][:lengths[row]]`. If
    `lengths` is a list of lists or tuple of lists, those lists will be used
    as nested row lengths. If `padding` is specified, then any row *suffix*
    consisting entirely of `padding` will be excluded from the returned
    `RaggedTensor`.  If neither `lengths` nor `padding` is specified, then the
    returned `RaggedTensor` will have no absent/default values.

    Examples:

    ```python
    >>> dt = tf.constant([[5, 7, 0], [0, 3, 0], [6, 0, 0]])
    >>> tf.RaggedTensor.from_tensor(dt)
    <tf.RaggedTensor [[5, 7, 0], [0, 3, 0], [6, 0, 0]]>
    >>> tf.RaggedTensor.from_tensor(dt, lengths=[1, 0, 3])
    <tf.RaggedTensor [[5], [], [6, 0, 0]]>

    >>> tf.RaggedTensor.from_tensor(dt, padding=0)
    <tf.RaggedTensor [[5, 7], [0, 3], [6]]>

    >>> dt = tf.constant([[[5, 0], [7, 0], [0, 0]],
    ...                   [[0, 0], [3, 0], [0, 0]],
    ...                   [[6, 0], [0, 0], [0, 0]]])
    >>> tf.RaggedTensor.from_tensor(dt, lengths=([2, 0, 3], [1, 1, 2, 0, 1]))
    <tf.RaggedTensor [[[5], [7]], [], [[6, 0], [], [0]]]>
    ```

    Args:
      tensor: The `Tensor` to convert.  Must have rank `ragged_rank + 1` or
        higher.
      lengths: An optional set of row lengths, specified using a 1-D integer
        `Tensor` whose length is equal to `tensor.shape[0]` (the number of rows
        in `tensor`).  If specified, then `output[row]` will contain
        `tensor[row][:lengths[row]]`.  Negative lengths are treated as zero. You
        may optionally pass a list or tuple of lengths to this argument, which
        will be used as nested row lengths to construct a ragged tensor with
        multiple ragged dimensions.
      padding: An optional padding value.  If specified, then any row suffix
        consisting entirely of `padding` will be excluded from the returned
        RaggedTensor.  `padding` is a `Tensor` with the same dtype as `tensor`
        and with `shape=tensor.shape[ragged_rank + 1:]`.
      ragged_rank: Integer specifying the ragged rank for the returned
        `RaggedTensor`.  Must be greater than zero.
      name: A name prefix for the returned tensors (optional).

    Returns:
      A `RaggedTensor` with the specified `ragged_rank`.  The shape of the
      returned ragged tensor is compatible with the shape of `tensor`.

    Raises:
      ValueError: If both `lengths` and `padding` are specified, or if
        `ragged_rank` is not greater than zero.
      TypeError: If `ragged_rank` is not an `int`.
    """
    # Validate Python-level arguments before any ops are created.
    if lengths is not None and padding is not None:
      raise ValueError("Specify lengths or padding, but not both")
    if not isinstance(ragged_rank, int):
      raise TypeError("ragged_rank expected int, got %r" % ragged_rank)
    if ragged_rank <= 0:
      raise ValueError(
          "ragged_rank must be greater than 0; got %s" % ragged_rank)

    with ops.name_scope(name, "RaggedFromTensor", [tensor, lengths, padding]):
      tensor = ops.convert_to_tensor(tensor, name="tensor")
      tensor.shape.with_rank_at_least(ragged_rank + 1)
      input_shape = array_ops.shape(tensor, out_type=dtypes.int64)
      # Size of dimension 1; used as the maximum row length below.
      ncols = input_shape[1]

      # Handle ragged_rank>1 via recursion:
      # If the output should have multiple ragged dimensions, then first
      # flatten the tensor to eliminate all but the last ragged dimension,
      # and recursively convert that flattened tensor.  Then add on the splits
      # for the dimensions that we flattened out.
      if ragged_rank > 1:
        # Flatten `tensor` to eliminate all but the last ragged dimension.
        new_shape = array_ops.concat([
            constant_op.constant([-1], dtypes.int64), input_shape[ragged_rank:]
        ],
                                     axis=0)
        flattened = array_ops.reshape(tensor, new_shape)
        # Recursively convert the flattened tensor (with the default
        # ragged_rank=1).  `lengths` and `padding` are forwarded unchanged, so
        # they are applied to the rows of the flattened tensor.
        values = cls.from_tensor(flattened, lengths, padding)
        # The total number of elements in each dimension.  E.g., if
        # input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total.
        dim_size = math_ops.cumprod(input_shape)
        # Construct splits tensors for the dimensions that were flattened.
        # These splits are evenly spaced: each row of dimension `dim` holds
        # exactly input_shape[dim] items.
        new_splits = [
            math_ops.range(0, dim_size[dim - 1] + 1) * input_shape[dim]
            for dim in range(1, ragged_rank)
        ]
        return cls.from_nested_row_splits(values, new_splits)

      # If padding was specified, then use it to find row lengths.
      if padding is not None:
        padding = ops.convert_to_tensor(
            padding, name="padding", dtype=tensor.dtype)
        padding.shape.assert_is_compatible_with(tensor.shape[2:])

        # Find places where the padding is equal to the tensor.  (This will
        # broadcast `padding` across the outermost 2 dimensions of `tensor`,
        # so `has_default_value.shape = tensor.shape`.)
        has_default_value = math_ops.equal(padding, tensor)

        # If the padding isn't a scalar, then require that all values in the
        # padding match each item in the tensor.  After this block of code,
        # `has_default.shape = tensor.shape[:2]`.  (Unfortunately, we can't just
        # use reduce_all for both cases, because when you pass an empty `axis`
        # list to reduce_all, it reduces all axes; but we want it to reduce no
        # axes -- i.e., to be a no-op.)
        tensor_rank = array_ops.rank(tensor)
        reduce_axis = math_ops.range(2, tensor_rank)
        has_default = control_flow_ops.cond(
            tensor_rank > 2,
            lambda: math_ops.reduce_all(has_default_value, axis=reduce_axis),
            lambda: has_default_value)
        # Re-attach static shape information to the `cond` result: rank 2,
        # then whatever is statically known about the outer two dimensions.
        has_default.set_shape(tensor_shape.TensorShape([None, None]))
        has_default.set_shape(tensor.shape[:2])

        # Use has_default to find the length of each row: for each
        # non-default item in a row, calculate the length that the row needs to
        # have to include that item; and then take the max of those values
        # (across each row).
        has_nondefault = math_ops.logical_not(has_default)
        has_nondefault = math_ops.cast(has_nondefault, dtypes.int64)
        length_for_nondefault_value = (
            has_nondefault * array_ops.expand_dims(
                math_ops.range(1, ncols + 1), 0))
        lengths = math_ops.reduce_max(length_for_nondefault_value, axis=1)

      if lengths is not None:
        # Note: lengths computed from `padding` above is a Tensor (not a
        # list/tuple), so that case always takes the `else` branch below.
        if isinstance(lengths,
                      (list, tuple)) and len(lengths) and not isinstance(
                          lengths[0], (int, float)):
          # In this case, we've been given nested row lengths. Rather than
          # reconstructing the tensor mask directly, we can recreate it as
          # a boolean RaggedTensor, then densify that and use that as the
          # mask to clear out the unused data in the passed tensor.
          tensor.shape.with_rank_at_least(len(lengths) + 1)
          num_tokens = math_ops.reduce_sum(lengths[-1])
          ones_mask = array_ops.ones([num_tokens], dtype=dtypes.bool)
          ragged_mask = cls.from_nested_row_lengths(ones_mask, lengths)
          dense_ragged_mask = ragged_mask.to_tensor(default_value=False)
          masked_data = array_ops.boolean_mask(tensor, dense_ragged_mask)
          return cls.from_nested_row_lengths(masked_data, lengths)
        else:
          # If we have lengths (either directly supplied, or computed from
          # paddings), then use those to construct splits; and then use masking
          # to get the corresponding values.
          lengths = ragged_util.convert_to_int_tensor(lengths, "lengths",
                                                      dtypes.int64)
          lengths.shape.assert_has_rank(1)
          # Clamp each length into [0, ncols].
          lengths = math_ops.minimum(lengths, ncols)
          lengths = math_ops.maximum(lengths, 0)
          limits = math_ops.cumsum(lengths)
          splits = array_ops.concat(
              [array_ops.zeros([1], dtypes.int64), limits], axis=0)
          mask = array_ops.sequence_mask(lengths, maxlen=ncols)
          values = array_ops.boolean_mask(tensor, mask)
          return cls.from_row_splits(values, splits)

      # If neither padding nor lengths were specified, then create a splits
      # vector that contains no default values, and reshape the input tensor
      # to form the values for the RaggedTensor.
      nrows = input_shape[0]
      nvals = nrows * ncols
      splits = math_ops.range(nrows + 1) * ncols
      values_shape = array_ops.concat([[nvals], input_shape[2:]], axis=0)
      values = array_ops.reshape(tensor, values_shape)
      return cls.from_row_splits(values, splits)
1221
  def to_tensor(self, default_value=None, name=None):
    """Converts this `RaggedTensor` into a `tf.Tensor`.

    Example:

    ```python
    >>> rt = ragged.constant([[9, 8, 7], [], [6, 5], [4]])
    >>> print(rt.to_tensor())
    [[9 8 7]
     [0 0 0]
     [6 5 0]
     [4 0 0]]
    ```

    Args:
      default_value: Value to set for indices not specified in `self`. Defaults
        to zero.  `default_value` must be broadcastable to
        `self.shape[self.ragged_rank + 1:]`.
      name: A name prefix for the returned tensors (optional).

    Returns:
      A `Tensor` with shape `ragged.bounding_shape(self)` and the
      values specified by the non-empty values in `self`.  Empty values are
      assigned `default_value`.
    """
    with ops.name_scope(name, "RaggedToTensor", [self, default_value]):
      if default_value is not None:
        default_value = ops.convert_to_tensor(
            default_value, name="default_value", dtype=self.dtype)

      # If ragged_rank > 1, then recursively convert the ragged values into a
      # `Tensor` before we proceed.
      values = self.values
      if is_ragged(values):
        values = values.to_tensor(default_value)

      # Tile the default value, if necessary.
      if default_value is not None:
        # A default value may not have more dims than a single row item.
        if values.shape.ndims is not None:
          default_value.shape.with_rank_at_most(values.shape.ndims - 1)
        # Broadcast unless the rank is statically known to already match
        # the shape of one item of `values`.
        if (values.shape.ndims is None or default_value.shape.ndims is None or
            values.shape.ndims != default_value.shape.ndims + 1):
          value_shape = array_ops.shape(values)[1:]
          default_value = array_ops.broadcast_to(default_value, value_shape)
        default_value.shape.assert_is_compatible_with(values.shape[1:])

      # Get the expected dense shape ([nrows, ncols] + value_shape).
      # NOTE(review): the row lengths are wrapped in a one-element list, which
      # reduce_max packs into a [1, nrows] tensor; reducing over all axes gives
      # the same maximum as reducing the row lengths directly.
      rt_row_lengths = [self.row_splits[1:] - self.row_splits[:-1]]
      nrows = array_ops.shape(self.row_splits, out_type=dtypes.int64)[0] - 1
      # Clamp at zero so a tensor with no rows yields ncols == 0 rather than
      # the result of reduce_max over an empty tensor.
      ncols = math_ops.maximum(math_ops.reduce_max(rt_row_lengths), 0)
      values_shape = array_ops.shape(values, out_type=dtypes.int64)
      value_shape = values_shape[1:]
      nvals = values_shape[0]

      # Build a default value if none was supplied.
      if default_value is None:
        default_value = array_ops.zeros(value_shape, dtype=values.dtype)
      default_value.shape.assert_is_compatible_with(values.shape[1:])
      default_value.set_shape(values.shape[1:])

      # Get the row start indices, and expand to shape=[nrows, 1].
      starts = array_ops.expand_dims(self.row_splits[:-1], 1)

      # Get the row limit indices, and expand to shape=[nrows, 1].
      limits = array_ops.expand_dims(self.row_splits[1:], 1)

      # Get the column indices, and expand to shape=[1, ncols].
      columns = array_ops.expand_dims(math_ops.range(0, ncols), 0)

      # Build a list containing the values plus the default value.  We will use
      # tf.gather to collect values from this list for the `Tensor` (using
      # nvals as the index for the default value).
      values_and_default = array_ops.concat(
          [values, array_ops.stack([default_value])], axis=0)

      # Construct a matrix "indices" pointing into values_and_default.  I.e.,
      # output[r, c] = values_and_default[indices[r, c]].
      nondefault_index = starts + columns
      has_value = nondefault_index < limits
      default_index = array_ops.fill(array_ops.stack([nrows, ncols]), nvals)
      indices = array_ops.where(has_value, nondefault_index, default_index)

      # Gather the results into a `Tensor`.
      return array_ops.gather(values_and_default, indices)
1306
1307  @classmethod
1308  def from_sparse(cls, st_input, name=None):
1309    """Converts a 2D `tf.SparseTensor` to a `RaggedTensor`.
1310
1311    Each row of the `output` `RaggedTensor` will contain the explicit values
1312    from the same row in `st_input`.  `st_input` must be ragged-right.  If not
1313    it is not ragged-right, then an error will be generated.
1314
1315    Example:
1316
1317    ```python
1318    >>> st = SparseTensor(indices=[[0, 1], [0, 2], [0, 3], [1, 0], [3, 0]],
1319    ...                   values=[1, 2, 3, 4, 5],
1320    ...                   dense_shape=[4, 3])
1321    >>> rt.RaggedTensor.from_sparse(st).eval().tolist()
1322    [[1, 2, 3], [4], [], [5]]
1323    ```
1324
1325    Currently, only two-dimensional `SparseTensors` are supported.
1326
1327    Args:
1328      st_input: The sparse tensor to convert.  Must have rank 2.
1329      name: A name prefix for the returned tensors (optional).
1330
1331    Returns:
1332      A `RaggedTensor` with the same values as `st_input`.
1333      `output.ragged_rank = rank(st_input) - 1`.
1334      `output.shape = [st_input.dense_shape[0], None]`.
1335    Raises:
1336      ValueError: If the number of dimensions in `st_input` is not known
1337        statically, or is not two.
1338    """
1339    if not sparse_tensor.is_sparse(st_input):
1340      raise TypeError("Expected SparseTensor, got %s" % type(st_input).__name__)
1341    with ops.name_scope(name, "RaggedFromSparse", [st_input]):
1342      st_input = sparse_tensor.convert_to_tensor_or_sparse_tensor(
1343          st_input, name="st_input")
1344
1345      if st_input.dense_shape.shape.ndims is None:
1346        static_rank_from_dense_shape = None
1347      else:
1348        static_rank_from_dense_shape = st_input.dense_shape.shape.dims[0].value
1349
1350      if st_input.indices.shape.ndims is None:
1351        static_rank_from_indices = None
1352      else:
1353        static_rank_from_indices = st_input.indices.shape.dims[1].value
1354
1355      if static_rank_from_dense_shape != 2 and static_rank_from_indices != 2:
1356        raise ValueError("rank(st_input) must be 2")
1357
1358      with ops.control_dependencies(
1359          _assert_sparse_indices_are_ragged_right(st_input.indices)):
1360        # Treat sparse row indices as segment ids to generate a splits tensor
1361        # thta we can pair with the sparse tensor values.  (Ignore sparse column
1362        # indices.)
1363        segment_ids = st_input.indices[:, 0]
1364        num_segments = st_input.dense_shape[0]
1365        return cls.from_value_rowids(st_input.values, segment_ids, num_segments)
1366
1367  def to_sparse(self, name=None):
1368    """Converts this `RaggedTensor` into a `tf.SparseTensor`.
1369
1370    Example:
1371
1372    ```python
1373    >>> rt = ragged.constant([[1, 2, 3], [4], [], [5, 6]])
1374    >>> rt.to_sparse().eval()
1375    SparseTensorValue(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [3, 1]],
1376                      values=[1, 2, 3, 4, 5, 6],
1377                      dense_shape=[4, 3])
1378    ```
1379
1380    Args:
1381      name: A name prefix for the returned tensors (optional).
1382
1383    Returns:
1384      A SparseTensor with the same values as `self`.
1385    """
1386    with ops.name_scope(name, "RaggedToSparse", [self]):
1387      result = gen_ragged_conversion_ops.ragged_tensor_to_sparse(
1388          self.nested_row_splits, self.flat_values, name=name)
1389      return sparse_tensor.SparseTensor(result.sparse_indices,
1390                                        result.sparse_values,
1391                                        result.sparse_dense_shape)
1392
1393  #=============================================================================
1394  # String Encoding
1395  #=============================================================================
1396  def __str__(self):
1397    if self._is_eager():
1398      return "<tf.RaggedTensor %s>" % self.to_list()
1399    else:
1400      return self.__repr__()
1401
1402  def __repr__(self):
1403    return "tf.RaggedTensor(values=%s, row_splits=%s)" % (self._values,
1404                                                          self._row_splits)
1405
1406  #=============================================================================
1407  # Eager Execution Mode
1408  #=============================================================================
1409
1410  def to_list(self):
1411    """Returns a nested Python `list` with the values for this `RaggedTensor`.
1412
1413    Requires that `rt` was constructed in eager execution mode.
1414
1415    Returns:
1416      A nested Python `list`.
1417    """
1418    if self._is_eager():
1419      return self._eager_value().to_list()
1420    else:
1421      raise ValueError("RaggedTensor.to_list() is only supported in eager "
1422                       "mode; in graph mode, evaluate the RaggedTensor first "
1423                       "and then use RaggedTensorValue.to_list().")
1424
1425  def _eager_value(self):
1426    """Returns a RaggedTensorValue for self.  Requires self._is_eager()=true."""
1427    value = self.flat_values.numpy()
1428    for row_splits in reversed(self.nested_row_splits):
1429      value = ragged_tensor_value.RaggedTensorValue(value, row_splits.numpy())
1430    return value
1431
1432  def _is_eager(self):
1433    """Returns True if values & row_splits Tensors are all `EagerTensor`s."""
1434    rt = self
1435    while isinstance(rt, RaggedTensor):
1436      if not isinstance(rt.row_splits, ops.EagerTensor):
1437        return False
1438      rt = rt.values
1439    return isinstance(rt, ops.EagerTensor)
1440
1441  #=============================================================================
1442  # Indexing & Slicing
1443  #=============================================================================
  def __getitem__(self, key):
    """Returns the specified piece of this RaggedTensor.

    Placeholder: this definition has no body beyond its docstring (so it would
    return `None`); the real implementation is installed elsewhere.
    """
    # See ragged_getitem.py for the documentation and implementation of this
    # method.
    #
    # Note: the imports in ragged/__init__.py ensure that this method always
    # gets overridden before it is called.
1451
1452  #=============================================================================
1453  # Name Scope
1454  #=============================================================================
1455
1456  # This private function is used by ops.name_scope to ensure that all of the
1457  # input tensors for the scope belong to the same graph.  Defining this means
1458  # that you may include `RaggedTensor` objects in the name_scope `values`
1459  # list.
1460  def _as_graph_element(self):
1461    """Convert `self` to a graph element."""
1462    values = self.values
1463    while isinstance(values, RaggedTensor):
1464      values = values.values
1465    return values
1466
1467  #=============================================================================
1468  # Composite Tensor
1469  #=============================================================================
1470
1471  def _to_components(self):
1472    return (self.flat_values,) + self.nested_row_splits
1473
1474  @classmethod
1475  def _from_components(cls, components):
1476    return cls.from_nested_row_splits(components[0], components[1:])
1477
1478  def _shape_invariant_to_components(self, shape=None):
1479    ragged_rank = self.ragged_rank
1480    flat_values = self.flat_values
1481
1482    if shape is None:
1483      # Default shape invariant
1484      value_shape = flat_values.shape[1:]
1485      values_shape = tensor_shape.TensorShape([None]).concatenate(value_shape)
1486      return ((values_shape, self._row_splits.shape) +
1487              tuple(tensor_shape.TensorShape([None])
1488                    for i in range(1, ragged_rank)))
1489    else:
1490      # Explicitly specified shape invariant
1491      if shape.ndims is not None and shape.ndims <= ragged_rank:
1492        raise ValueError("Shape invariant %s does not have sufficient rank "
1493                         "for a RaggedTensor with %d ragged dimensions." %
1494                         (shape, self.ragged_rank))
1495      if any(tensor_shape.dimension_value(shape[dim]) is not None
1496             for dim in range(1, self.ragged_rank + 1)):
1497        raise ValueError("Shape invariant dimension size must be None for "
1498                         "ragged dimenions.")
1499      nrows = tensor_shape.dimension_value(shape[0])
1500      value_shape = shape[self.ragged_rank + 1:]
1501      values_shape = tensor_shape.TensorShape([None]).concatenate(value_shape)
1502      if nrows is None:
1503        outer_splits_shape = tensor_shape.TensorShape([None])
1504      else:
1505        outer_splits_shape = tensor_shape.TensorShape([nrows + 1])
1506      return ((values_shape, outer_splits_shape) +
1507              tuple(tensor_shape.TensorShape([None])
1508                    for i in range(1, ragged_rank)))
1509
1510  @property
1511  def _is_graph_tensor(self):
1512    return hasattr(self._values, 'graph')
1513
1514
def is_ragged(value):
  """Returns True if `value` is a `RaggedTensor` or a `RaggedTensorValue`."""
  ragged_types = (RaggedTensor, ragged_tensor_value.RaggedTensorValue)
  return isinstance(value, ragged_types)
1519
1520
1521#===============================================================================
1522# Convert value -> tensor
1523#===============================================================================
def convert_to_tensor_or_ragged_tensor(value,
                                       dtype=None,
                                       preferred_dtype=None,
                                       name=None):
  """Converts value to a `RaggedTensor` or `Tensor`.

  * If `value` is a `RaggedTensor`, then return it as-is.
  * If `value` is a `RaggedTensorValue`, return a corresponding constant
    `RaggedTensor`.
  * Otherwise, use `convert_to_tensor` to convert `value` to a `Tensor`.

  Args:
    value: A `RaggedTensor`, a `RaggedTensorValue`, or an object whose type has
      a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor: a `DType` or any
      value accepted by `dtypes.as_dtype` (e.g. `"float32"`).  If missing, the
      type is inferred from the type of `value`.
    preferred_dtype: Optional element type for the returned tensor, used when
      dtype is None.  This argument has no effect if `value` is already a
      tensor, or when conversion is not possible.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or `RaggedTensor`.

  Raises:
    ValueError: If `value` is a `RaggedTensor` whose dtype is not compatible
      with `dtype`.
    TypeError: If `dtype` cannot be converted to a `DType`.
  """
  # Normalize `dtype` up front.  The RaggedTensor branch calls
  # `is_compatible_with` directly, so without this it would reject dtype
  # spellings (strings, numpy types) that `ops.convert_to_tensor` accepts.
  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)

  if isinstance(value, RaggedTensor):
    if dtype is not None and not dtype.is_compatible_with(value.dtype):
      raise ValueError("Tensor conversion requested dtype %s for "
                       "RaggedTensor with dtype %s: %r" %
                       (dtype.name, value.dtype.name, value))
    return value
  elif isinstance(value, ragged_tensor_value.RaggedTensorValue):
    with ops.name_scope(name, "ConvertToTensorOrRaggedTensor", []):
      flat_values = ops.convert_to_tensor(
          value=value.flat_values,
          dtype=dtype,
          preferred_dtype=preferred_dtype,
          name="flat_values")
      return RaggedTensor.from_nested_row_splits(flat_values,
                                                 value.nested_row_splits)
  else:
    return ops.convert_to_tensor(
        value=value, dtype=dtype, preferred_dtype=preferred_dtype, name=name)
1566
1567
1568#===============================================================================
1569# Register RaggedTensor for use with session.run.
1570#===============================================================================
1571def _ragged_tensor_value_from_components(components):
1572  components = list(components)
1573  value = components.pop()
1574  while components:
1575    value = ragged_tensor_value.RaggedTensorValue(value, components.pop())
1576  return value
1577
1578
def _ragged_tensor_session_fetch(rt):
  """`Session.run` fetch function for RaggedTensors.

  Returns a `(fetches, rebuild_fn)` pair: `fetches` is `rt.nested_row_splits`
  followed by `rt.flat_values`, and `rebuild_fn` reassembles the fetched
  values into a `RaggedTensorValue`.
  """
  fetches = rt.nested_row_splits + (rt.flat_values,)
  rebuild_fn = _ragged_tensor_value_from_components
  return fetches, rebuild_fn
1582
1583
1584def _ragged_tensor_session_feed(feed_key, feed_val):
1585  key_components = feed_key.nested_row_splits + (feed_key.flat_values,)
1586  val_components = feed_val.nested_row_splits + (feed_val.flat_values,)
1587  return zip(key_components, val_components)
1588
1589
1590def _ragged_tensor_session_feed_for_partial_run(feed_key):
1591  return feed_key.nested_row_splits + (feed_key.flat_values,)
1592
1593
# Register conversion functions so that `Session.run` can fetch and feed
# RaggedTensors: each RaggedTensor is split into its component tensors
# (nested row_splits followed by flat_values), and fetched results are
# reassembled into a `RaggedTensorValue`.
session.register_session_run_conversion_functions(
    RaggedTensor, _ragged_tensor_session_fetch, _ragged_tensor_session_feed,
    _ragged_tensor_session_feed_for_partial_run)
1597
1598
1599#===============================================================================
1600# RaggedTensorType
1601#===============================================================================
class RaggedTensorType(object):
  """Static type descriptor for a `RaggedTensor`.

  Records the inner-values dtype and the ragged rank that a `RaggedTensor` is
  declared to have, e.g. to state the type an output must take.
  """

  def __init__(self, dtype, ragged_rank):
    """Creates a `RaggedTensorType`.

    Args:
      dtype: The dtype of the declared `RaggedTensor`'s inner values.
      ragged_rank: The ragged_rank of the declared `RaggedTensor`.
    """
    self._dtype = dtype
    self._ragged_rank = ragged_rank

  @property
  def dtype(self):
    """The dtype of the declared `RaggedTensor`'s inner values (read-only)."""
    return self._dtype

  @property
  def ragged_rank(self):
    """The ragged_rank of the declared `RaggedTensor` (read-only)."""
    return self._ragged_rank
1621
1622
1623#===============================================================================
1624# Helper Functions
1625#===============================================================================
def _assert_sparse_indices_are_ragged_right(indices):
  """Checks that the given SparseTensor.indices tensor is ragged-right.

  Example: `indices = [[0, 0], [0, 1], [2, 0], [3, 1]]` is not ragged right
  because the entry `[3, 1]` skips a cell.

  Only the last coordinate of each index is constrained: within each run of
  indices that share the same leading coordinates, the last coordinate must
  count up 0, 1, 2, ...  The leading coordinates themselves may skip values
  (e.g. `[0, 0]` followed by `[2, 0]` is accepted).  Each index is compared
  only with its immediate predecessor, so the check passes only if indices
  that share leading coordinates are adjacent and in increasing order.

  Args:
    indices: The SparseTensor indices to check, indexed as a 2-D
      `[num_entries, rank]` tensor.

  Returns:
    A single-element list containing the `Assert` op, for use as control
    dependencies.
  """
  # Leading coordinates identify the innermost row; the last coordinate is the
  # position within that row.
  index_prefix = indices[:, :-1]
  index_suffix = indices[:, -1]

  # Check whether each index is starting a new row in the innermost dimension
  # (prefix[i] != prefix[i-1]) or continuing a row (prefix[i] == prefix[i-1]).
  # (Note: this skips the first index; we will check that separately below.)
  index_prefix_changed = math_ops.reduce_any(
      math_ops.not_equal(index_prefix[1:], index_prefix[:-1]), axis=1)

  # Check two cases:
  #   * For indices that start a new row: index_suffix[i] must be zero.
  #   * For indices that continue a row: index_suffix[i] must be equal to
  #     index_suffix[i-1]+1.
  index_ok = array_ops.where(
      index_prefix_changed, math_ops.equal(index_suffix[1:], 0),
      math_ops.equal(index_suffix[1:], index_suffix[:-1] + 1))

  # Also check that the very first index didn't skip any cells.  The first
  # index starts a new row (by definition), so its suffix should be zero.
  sparse_indices_are_ragged_right = math_ops.logical_and(
      math_ops.reduce_all(math_ops.equal(index_suffix[:1], 0)),
      math_ops.reduce_all(index_ok))

  # On failure, the Assert op reports the message along with the full indices.
  message = [
      "SparseTensor is not right-ragged", "SparseTensor.indices =", indices
  ]
  return [control_flow_ops.Assert(sparse_indices_are_ragged_right, message)]
1665
1666
@ops.RegisterGradient("RaggedTensorToSparse")
def _ragged_tensor_to_sparse_gradient(op, unused_sparse_indices_grad,
                                      sparse_values_grad,
                                      unused_sparse_shape_grad):
  """Gradient function for the `RaggedTensorToSparse` op.

  The op's inputs are the RaggedTensor's nested row_splits followed by its
  flat_values.  The row_splits receive no gradient.  The flat_values gradient
  is the incoming gradient for the SparseTensor's values, reshaped to the
  shape of the flat_values input.

  Args:
    op: The `RaggedTensorToSparse` operation being differentiated.
    unused_sparse_indices_grad: Gradient for the sparse indices output
      (ignored).
    sparse_values_grad: Gradient for the sparse values output.
    unused_sparse_shape_grad: Gradient for the sparse shape output (ignored).

  Returns:
    A list containing one `None` per row_splits input, followed by the
    gradient for the flat_values input.
  """
  flat_values_input = op.inputs[-1]
  num_row_splits_inputs = len(op.inputs) - 1

  grads = [None] * num_row_splits_inputs
  grads.append(
      array_ops.reshape(sparse_values_grad,
                        array_ops.shape(flat_values_input)))
  return grads
1685