1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Shapes & broadcasting for RaggedTensors."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops.ragged import ragged_array_ops
30from tensorflow.python.ops.ragged import ragged_conversion_ops
31from tensorflow.python.ops.ragged import ragged_tensor
32from tensorflow.python.ops.ragged import ragged_util
class RaggedTensorDynamicShape(object):
  """A collection of tensors encoding the shape of a potentially ragged tensor.

  Each `RaggedTensorDynamicShape` consists of an ordered list of dimension
  sizes.  There are two dimension types:

    * "Uniform dimensions" are dimensions where all slices have the same
      length.  `RaggedTensorDynamicShape` records the size of each uniform
      dimension using a single scalar integer.

    * "Ragged dimensions" are dimensions whose slices may have different
      lengths.  `RaggedTensorDynamicShape` records the size of each ragged
      dimension using an integer vector containing the slice lengths for all
      the slices across that dimension.

  Furthermore, there are two ways a dimension might be encoded:

    * "Partitioned dimensions" are dimensions that are encoded using a
      `RaggedTensor`'s `nested_row_splits`.  The outermost partitioned
      dimension must be uniform, and the innermost partitioned dimension must
      be ragged.

    * "Inner dimensions" are dimensions that are encoded using a
      `RaggedTensor`'s `flat_values`.  Inner dimensions are always uniform.

  The sizes of the dimensions are recorded using `partitioned_dim_sizes`
  and `inner_dim_sizes`:

    * `partitioned_dim_sizes` is a list of tensors (one for each partitioned
      dimension).

      * For uniform dimensions, the tensor is an integer scalar specifying the
        size of all slices across that dimension.
      * For ragged dimensions, the tensor is an integer vector specifying the
        size of each slice across that dimension.

    * `inner_dim_sizes` is a single integer vector, where each element
      specifies the size of a single inner dimension.

  Examples:

  Tensor                         | Ragged | Partitioned Dim Sizes  | Inner Dim
                                 : Rank   :                        : Sizes
  ------------------------------ | ------ | ---------------------- | ----------
  `[[1, 2, 3], [4, 5, 6]]`       |      0 |                        | `2, 3`
  `[[1, 2], [], [3, 4, 5]]`      |      1 | `3, (2, 0, 3)`         |
  `[[[1, 2], [3, 4]], [[5, 6]]]` |      1 | `2, (2, 1)`            | 2
  `[[[1, 2], [3]], [[4, 5]]]`    |      2 | `2, (2, 1), (2, 1, 2)` |
  """

  def __init__(self, partitioned_dim_sizes, inner_dim_sizes):
    """Creates a RaggedTensorDynamicShape.

    Args:
      partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for
        each partitioned dimension.  If dimension `d` is uniform, then
        `partitioned_dim_sizes[d]` must be an integer scalar, specifying the
        size of all slices across dimension `d`.  If dimension `d` is ragged,
        then `partitioned_dim_sizes[d]` must be an integer vector, specifying
        the size of each slice across dimension `d`.
      inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the
        number of inner dimensions.  `inner_dim_sizes[n]` is the size of all
        slices across the `n`th inner dimension (which is the
        `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor).

    Raises:
      ValueError: If the rank of any partitioned dimension size is unknown, if
        the outermost partitioned dimension is not a scalar, or if the
        innermost partitioned dimension is not a vector.
    """
    assert isinstance(partitioned_dim_sizes, (list, tuple))
    with ops.name_scope(None, 'RaggedTensorDynamicShape',
                        (partitioned_dim_sizes, inner_dim_sizes)):
      # Everything is stored as int64 tensors.
      partitioned_dim_sizes = tuple(
          ragged_util.convert_to_int_tensor(
              size, dtype=dtypes.int64, name='partitioned_dimension_size')
          for size in partitioned_dim_sizes)
      inner_dim_sizes = ragged_util.convert_to_int_tensor(
          inner_dim_sizes, dtype=dtypes.int64, name='inner_dim_sizes')

      # Validate shapes.
      if partitioned_dim_sizes:
        for axis, dimension_size in enumerate(partitioned_dim_sizes):
          if dimension_size.shape.ndims is None:
            raise ValueError(
                'rank of partitioned_dim_sizes[%d] is unknown' % axis)
          # Called for its validation side effect only (the result is
          # discarded); expected to reject sizes of rank > 1.
          dimension_size.shape.with_rank_at_most(1)
        if partitioned_dim_sizes[0].shape.ndims == 1:
          raise ValueError('outermost partitioned dimension must be uniform')
        if partitioned_dim_sizes[-1].shape.ndims == 0:
          raise ValueError('innermost partitioned dimension must be ragged')
      inner_dim_sizes.shape.assert_has_rank(1)

      self._partitioned_dim_sizes = partitioned_dim_sizes
      self._inner_dim_sizes = inner_dim_sizes

  def __repr__(self):
    return ('RaggedTensorDynamicShape'
            '(partitioned_dim_sizes=%r, inner_dim_sizes=%r)' %
            (self._partitioned_dim_sizes, self._inner_dim_sizes))

  @staticmethod
  def from_dim_sizes(dim_sizes):
    """Constructs a ragged shape from a list of dimension sizes.

    This list contains a single tensor for each dimension, where the tensor
    is a scalar if the dimension is uniform, or a vector if the dimension is
    ragged.

    Args:
      dim_sizes: List of int64 scalars or vectors.

    Returns:
      A RaggedTensorDynamicShape.

    Raises:
      ValueError: If any element of `dim_sizes` is neither a scalar nor a
        vector.
    """
    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes',
                        [dim_sizes]):
      dim_sizes = tuple(
          ragged_util.convert_to_int_tensor(
              size, dtype=dtypes.int64, name='dim_sizes') for size in dim_sizes)
      # Split the dimensions into partitioned & inner dimensions: everything up
      # to and including the last vector-valued (ragged) size is partitioned;
      # the trailing scalars are inner dimensions.
      inner_split = 0
      for dim, dim_size in enumerate(dim_sizes):
        if dim_size.shape.ndims == 1:
          inner_split = dim + 1
        elif dim_size.shape.ndims != 0:
          raise ValueError('Each dim_size must be a scalar or a vector')
      return RaggedTensorDynamicShape(dim_sizes[:inner_split],
                                      dim_sizes[inner_split:])

  @classmethod
  def from_tensor(cls, rt_input):
    """Constructs a ragged shape for a potentially ragged tensor.

    Args:
      rt_input: A value accepted by
        `ragged_tensor.convert_to_tensor_or_ragged_tensor`.

    Returns:
      An instance of `cls`.  For non-ragged input there are no partitioned
      dimensions and the inner dimension sizes are the input's dynamic shape.
      For ragged input the partitioned dimension sizes are `nrows()` followed
      by `nested_row_lengths()`, and the inner dimension sizes are the shape
      of `flat_values` without its first dimension.
    """
    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]):
      rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
      if not ragged_tensor.is_ragged(rt_input):
        return cls([], array_ops.shape(rt_input))
      partitioned_dim_sizes = (
          (rt_input.nrows(),) + rt_input.nested_row_lengths())
      # Use `cls` here as well, so both branches return the same type.
      return cls(partitioned_dim_sizes,
                 array_ops.shape(rt_input.flat_values)[1:])

  def dimension_size(self, axis):
    """Returns the size of slices across the specified dimension.

    Args:
      axis: `int`.  Index of the dimension.  No range check is performed here.

    Returns:
      The entry of `partitioned_dim_sizes` for `axis` if it indexes a
      partitioned dimension; otherwise the matching element of
      `inner_dim_sizes`.

    Raises:
      TypeError: If `axis` is not an `int`.
    """
    if not isinstance(axis, int):
      raise TypeError('axis must be an integer')
    partitioned_ndims = len(self._partitioned_dim_sizes)
    if axis < partitioned_ndims:
      return self._partitioned_dim_sizes[axis]
    else:
      return self._inner_dim_sizes[axis - partitioned_ndims]

  def is_ragged(self, axis):
    """Returns true if the indicated dimension is ragged.

    Args:
      axis: `int`.  Non-negative index of the dimension.

    Raises:
      TypeError: If `axis` is not an `int`.
      ValueError: If `axis` is negative, or if the rank is statically known
        and `axis` is not less than it.
    """
    if not isinstance(axis, int):
      raise TypeError('axis must be an integer')
    rank = self.rank
    if axis < 0:
      raise ValueError('Negative axis values are not supported')
    elif rank is not None and axis >= rank:
      raise ValueError('Expected axis=%s < rank=%s' % (axis, rank))
    else:
      # Only a partitioned dimension other than the outermost one can be
      # ragged; it is ragged iff its size is recorded as a vector.
      return (axis > 0 and axis < len(self._partitioned_dim_sizes) and
              self._partitioned_dim_sizes[axis].shape.ndims == 1)

  @property
  def rank(self):
    """The number of dimensions in this shape, or None if unknown."""
    inner_ndims = self.num_inner_dimensions
    if inner_ndims is None:
      return None
    else:
      return len(self._partitioned_dim_sizes) + inner_ndims

  @property
  def partitioned_dim_sizes(self):
    """The partitioned dimension sizes for this shape.

    Returns:
      A `list` of 0-D or 1-D integer `Tensor`.
    """
    return self._partitioned_dim_sizes

  @property
  def inner_dim_sizes(self):
    """The inner dimension sizes for this shape.

    Returns:
      A 1-D integer `Tensor`.
    """
    return self._inner_dim_sizes

  @property
  def num_partitioned_dimensions(self):
    """The number of partitioned dimensions in this shape."""
    return len(self._partitioned_dim_sizes)

  @property
  def num_inner_dimensions(self):
    """The number of inner dimensions, or `None` if not statically known."""
    return tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])

  def broadcast_to_rank(self, rank):
    """Adds leading size-1 dimensions to broadcast `self` to the given rank.

    E.g., if `shape1` is `[3, (D2), 4]`, then `shape1.broadcast_to_rank(5)`
    is `[1, 1, 3, (D2), 4]`.

    Args:
      rank: The rank for the returned shape.

    Returns:
      A RaggedTensorDynamicShape with `rank` dimensions, whose inner dimensions
      have the same size as `self` and whose outer dimensions have size `1`.
      `self` is returned unchanged when it already has the requested rank.

    Raises:
      ValueError: If `self.rank` is unknown or greater than `rank`.
    """
    self_rank = self.rank
    if self_rank is None:
      raise ValueError('Unable to broadcast: self.rank is unknown')
    dims_to_add = rank - self_rank
    if dims_to_add < 0:
      raise ValueError('Unable to broadcast: rank=%d must be greater than or '
                       'equal to self.rank=%d.' % (rank, self_rank))
    elif dims_to_add == 0:
      return self
    elif self._partitioned_dim_sizes:
      # New leading dimensions are uniform partitioned dimensions of size 1.
      partitioned_dims = (1,) * dims_to_add + self._partitioned_dim_sizes
      return RaggedTensorDynamicShape(partitioned_dims, self._inner_dim_sizes)
    else:
      # No partitioned dimensions: prepend ones to the inner sizes instead.
      inner_dims = array_ops.concat(
          [array_ops.ones([dims_to_add], dtypes.int64), self.inner_dim_sizes],
          axis=0)
      return RaggedTensorDynamicShape([], inner_dims)

  def broadcast_dimension(self, axis, lengths):
    """Returns a shape that is broadcast-compatible with self & lengths.

    * If dimension[axis] is uniform and lengths is a scalar, then check
      that either lengths==1 or dimension[axis]==1 or
      lengths==dimension[axis]; a size-1 dimension takes the size `lengths`.

    * If dimension[axis] is uniform and lengths is a vector, then check
      that dimension[axis]==1, and make dimension[axis] ragged with slice
      lengths `lengths`.

    * If dimension[axis] is ragged and lengths is a scalar, then check
      that lengths==1.

    * If dimension[axis] is ragged and lengths is a vector, then check
      that self.dimension_size(axis) == lengths.

    The checks are expressed as an `Assert` op that the returned shape's
    tensors are built under (via control dependencies).

    Args:
      axis: `int`.  The dimension to broadcast.
      lengths: 0-D or 1-D integer `Tensor`.

    Returns:
      A `RaggedTensorDynamicShape`.

    Raises:
      ValueError: If the rank of `lengths` is unknown or greater than 1, or if
        `lengths` is a vector and `axis` is an inner dimension with index 0.
    """
    lengths = ragged_util.convert_to_int_tensor(
        lengths, name='lengths', dtype=dtypes.int64)
    # Check whether lengths is a scalar (for uniform dimensions) or
    # vector (for ragged dimensions).
    if lengths.shape.ndims is None:
      raise ValueError('lengths must have a known rank.')
    elif lengths.shape.ndims > 1:
      raise ValueError('lengths must be a scalar or vector')
    else:
      lengths_is_scalar = (lengths.shape.ndims == 0)

    # `is_ragged` also validates `axis`, so it must run before anything else
    # indexes with it.
    axis_is_ragged = self.is_ragged(axis)
    axis_dim_size = self.dimension_size(axis)

    # Verify that the shapes are compatible.
    if axis_is_ragged:
      if lengths_is_scalar:
        condition = math_ops.equal(lengths, 1)
      else:
        condition = math_ops.reduce_all(
            math_ops.equal(lengths, axis_dim_size))
    else:
      if lengths_is_scalar:
        condition = (
            math_ops.equal(lengths, 1) | math_ops.equal(axis_dim_size, 1)
            | math_ops.equal(axis_dim_size, lengths))
      else:
        condition = math_ops.equal(axis_dim_size, 1)
    broadcast_err = [
        'Unable to broadcast: dimension size mismatch in dimension', axis,
        'lengths=', lengths, 'dim_size=', axis_dim_size
    ]
    broadcast_check = control_flow_ops.Assert(
        condition, data=broadcast_err, summarize=10)

    with ops.control_dependencies([broadcast_check]):
      # Partitioned dimensions:
      if axis < self.num_partitioned_dimensions:
        if axis_is_ragged:
          # Use an identity op to make sure the check actually gets run.
          return RaggedTensorDynamicShape(
              self._partitioned_dim_sizes,
              array_ops.identity(self.inner_dim_sizes))
        else:
          return self._broadcast_uniform_partitioned_dimension(axis, lengths)

      # Inner dimensions:
      else:
        if lengths_is_scalar:
          return self._broadcast_inner_dimension_to_uniform(axis, lengths)
        else:
          if axis == 0:
            raise ValueError('Unable to broadcast: '
                             'outermost dimension must be uniform.')
          return self._broadcast_inner_dimension_to_ragged(axis, lengths)

  def num_slices_in_dimension(self, axis):
    """Returns the total number of slices across the indicated dimension.

    Args:
      axis: `int`.  A negative value yields the constant `1`, which serves as
        the base case of the recursion below.

    Returns:
      An int64 scalar: the sum of the slice lengths if `axis` is ragged,
      otherwise the size of `axis` times the number of slices in `axis - 1`.
    """
    if axis < 0:
      return constant_op.constant(1, dtype=dtypes.int64)
    elif self.is_ragged(axis):
      return math_ops.reduce_sum(self._partitioned_dim_sizes[axis])
    else:
      return self.dimension_size(axis) * self.num_slices_in_dimension(axis - 1)

  def _broadcast_uniform_partitioned_dimension(self, axis, lengths):
    """Broadcasts the partitioned dimension `axis` to match `lengths`.

    Args:
      axis: `int`.  Index of a uniform partitioned dimension.
      lengths: 0-D or 1-D int64 `Tensor` to broadcast the dimension against.

    Returns:
      A `RaggedTensorDynamicShape` with unchanged outer and inner dimensions,
      `axis` resized to `lengths`, and the partitioned dimensions inside
      `axis` adjusted to match.
    """
    axis_dim_size = self.dimension_size(axis)
    partitioned_sizes = list(self._partitioned_dim_sizes[:axis])

    # NOTE(review): `splits` and `repeats` are fed to
    # `ragged_util.repeat_ranges` below; presumably `splits` delimits the
    # ranges of the next dimension's sizes that belong to each slice of
    # `axis`, and `repeats` is how often each range is repeated -- confirm
    # against `ragged_util`.
    if lengths.shape.ndims == 0:
      # Scalar: a size-1 dimension takes the new size; otherwise the
      # existing size is kept and nothing is repeated.
      lengths = array_ops.where(
          math_ops.equal(axis_dim_size, 1), lengths, axis_dim_size)
      repeats = array_ops.where(math_ops.equal(axis_dim_size, 1), lengths, 1)
      splits = array_ops.stack([0, self.num_slices_in_dimension(axis)])
    else:
      splits = math_ops.range(
          array_ops.size(lengths, out_type=dtypes.int64) + 1)
      repeats = lengths

    partitioned_sizes.append(lengths)

    # Propagate the change to the partitioned dimensions inside `axis`.
    for dim_size in self._partitioned_dim_sizes[axis + 1:]:
      if dim_size.shape.ndims == 0:
        partitioned_sizes.append(dim_size)
        splits *= dim_size
      else:
        partitioned_sizes.append(
            ragged_util.repeat_ranges(dim_size, splits, repeats))
        splits = array_ops.gather(
            ragged_util.lengths_to_splits(dim_size), splits)
    inner_sizes = self._inner_dim_sizes
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def _broadcast_inner_dimension_to_uniform(self, axis, length):
    """Broadcasts the inner dimension `axis` to match the scalar `length`.

    Args:
      axis: `int`.  Index of an inner dimension (in the overall shape).
      length: 0-D int64 `Tensor`.

    Returns:
      A `RaggedTensorDynamicShape` whose inner size at `axis` is `length` if
      the current size is 1, and the current size otherwise.
    """
    dim_size = self.dimension_size(axis)
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = self._partitioned_dim_sizes
    inner_sizes = array_ops.concat([
        self._inner_dim_sizes[:axis_in_inner_dims],
        [array_ops.where(math_ops.equal(dim_size, 1), length, dim_size)],
        self._inner_dim_sizes[axis_in_inner_dims + 1:]
    ], axis=0)
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def _broadcast_inner_dimension_to_ragged(self, axis, lengths):
    """Broadcasts the inner dimension `axis` to the ragged sizes `lengths`.

    The inner dimensions up to and including `axis` become partitioned
    dimensions: those before `axis` keep their scalar sizes, and `axis` itself
    takes the vector `lengths`.  The inner dimensions after `axis` stay inner.

    Args:
      axis: `int`.  Index of an inner dimension (in the overall shape).
      lengths: 1-D int64 `Tensor` of slice lengths.

    Returns:
      A `RaggedTensorDynamicShape`.
    """
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = (
        self._partitioned_dim_sizes +
        tuple(self._inner_dim_sizes[i] for i in range(axis_in_inner_dims)) +
        (lengths,))
    inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:]
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)
def broadcast_dynamic_shape(shape_x, shape_y):
  """Computes the shape that results from broadcasting two ragged shapes.

  Both shapes are first padded with leading dimensions to a common rank, and
  are then broadcast against each other one dimension at a time.

  Args:
    shape_x: A `RaggedTensorDynamicShape`
    shape_y: A `RaggedTensorDynamicShape`

  Returns:
    A `RaggedTensorDynamicShape`.

  Raises:
    TypeError: If either argument is not a `RaggedTensorDynamicShape`.
    ValueError: If the rank of either shape is unknown, or if `shape_x` and
      `shape_y` are not broadcast-compatible.
  """
  if not isinstance(shape_x, RaggedTensorDynamicShape):
    raise TypeError('shape_x must be a RaggedTensorDynamicShape')
  if not isinstance(shape_y, RaggedTensorDynamicShape):
    raise TypeError('shape_y must be a RaggedTensorDynamicShape')

  # Pad both shapes up to the larger of the two ranks.
  rank_x = shape_x.rank
  rank_y = shape_y.rank
  if rank_x is None or rank_y is None:
    raise ValueError('Unable to broadcast: unknown rank')
  common_rank = max(rank_x, rank_y)
  result_x = shape_x.broadcast_to_rank(common_rank)
  result_y = shape_y.broadcast_to_rank(common_rank)

  # Walk the dimensions from outermost to innermost.  Each step first adapts
  # x to y's size, then adapts y to the (already updated) size of x.
  for axis in range(common_rank):
    size_y = result_y.dimension_size(axis)
    result_x = result_x.broadcast_dimension(axis, size_y)
    size_x = result_x.dimension_size(axis)
    result_y = result_y.broadcast_dimension(axis, size_x)

  return result_x
def broadcast_to(rt_input, shape, broadcast_inner_dimensions=True):
  """Tiles a potentially ragged tensor so that it matches a ragged shape.

  The input is converted with
  `ragged_tensor.convert_to_tensor_or_ragged_tensor` and then handed to the
  uniform or ragged broadcasting helper, depending on whether `shape` has any
  partitioned dimensions.

  Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`.

  Args:
    rt_input: The potentially ragged tensor to broadcast.
    shape: A `RaggedTensorDynamicShape`
    broadcast_inner_dimensions: If false, then inner dimensions will not be
      tiled.

  Returns:
    A potentially ragged tensor whose values are taken from
    `rt_input`, and whose shape matches `shape`.

  Raises:
    TypeError: If `shape` is not a `RaggedTensorDynamicShape`.
  """
  if not isinstance(shape, RaggedTensorDynamicShape):
    raise TypeError('shape must be a RaggedTensorDynamicShape')
  converted = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)

  # A shape without partitioned dimensions is a plain uniform shape.
  if shape.num_partitioned_dimensions:
    broadcast_fn = _broadcast_to_ragged_shape
  else:
    broadcast_fn = _broadcast_to_uniform_shape
  return broadcast_fn(converted, shape, broadcast_inner_dimensions)
def _broadcast_to_uniform_shape(rt_input, shape, broadcast_inner_dimensions):
  """Broadcasts the non-ragged `rt_input` to the uniform shape `shape`.

  Args:
    rt_input: The tensor to broadcast; must not be a `RaggedTensor`.
    shape: A `RaggedTensorDynamicShape`; only its `inner_dim_sizes` are used.
    broadcast_inner_dimensions: If false, `rt_input` is returned unchanged.

  Returns:
    `rt_input` broadcast to `shape.inner_dim_sizes`, or `rt_input` itself when
    `broadcast_inner_dimensions` is false.

  Raises:
    ValueError: If `rt_input` is a `RaggedTensor`.
  """
  if isinstance(rt_input, ragged_tensor.RaggedTensor):
    raise ValueError('Incompatible with shape: ragged rank mismatch')
  if not broadcast_inner_dimensions:
    return rt_input
  target_dims = shape.inner_dim_sizes
  return array_ops.broadcast_to(rt_input, target_dims)
def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
  """Broadcasts rt_input to the ragged shape `dst_shape`.

  The work is done in stages: validate ranks, add leading dimensions so that
  `rt_input` has as many dimensions as `dst_shape`, convert dimensions to
  ragged form, tile the dimensions that stay uniform, optionally reshape the
  flat values to the destination inner dimensions, and finally tile the
  dimensions that become ragged.

  Args:
    rt_input: The potentially ragged tensor to broadcast.  It must expose
      `shape.ndims`.
    dst_shape: A `RaggedTensorDynamicShape` with at least one partitioned
      dimension (the caller dispatches here only in that case).
    broadcast_inner_dimensions: If false, the flat values are not reshaped to
      `dst_shape.inner_dim_sizes`.

  Returns:
    The broadcast `RaggedTensor`.

  Raises:
    ValueError: If the rank of `rt_input` or `dst_shape` is unknown, if
      `rt_input` has more dimensions than `dst_shape`, or if `rt_input` is a
      `RaggedTensor` whose `ragged_rank` is not smaller than
      `dst_shape.num_partitioned_dimensions`.
  """
  # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's
  if rt_input.shape.ndims is None or dst_shape.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  if rt_input.shape.ndims > dst_shape.rank:
    raise ValueError('Incompatible with shape: rank mismatch')
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
    raise ValueError('Incompatible with shape: ragged rank mismatch')

  # Source shape, padded with leading size-1 dimensions to dst_shape's rank.
  # It is computed from the original input, before rt_input is modified below.
  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)

  # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape.
  if dst_shape.rank > rt_input.shape.ndims:
    # If the input has fewer dimensions than dst_shape's inner dimensions plus
    # one, reshape it to [-1] + inner_dim_sizes first.
    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
      rt_input = array_ops.reshape(
          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))
    # Each iteration wraps all current rows into a single new outer row.  The
    # iteration count is fixed from the rank at loop entry.
    for _ in range(dst_shape.rank - rt_input.shape.ndims):
      if ragged_tensor.is_ragged(rt_input):
        nrows = rt_input.nrows()
      else:
        nrows = array_ops.shape(rt_input, out_type=dtypes.int64)[0]
      rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows])

  # Add ragged dimensions to match dst_shape.
  if ragged_tensor.is_ragged(rt_input):
    # Number of dimensions of flat_values (beyond its first) in excess of
    # dst_shape's inner dimensions; these are converted via `from_tensor`.
    inner_rank_diff = (
        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
    if inner_rank_diff > 0:
      rt_input = rt_input.with_flat_values(
          ragged_conversion_ops.from_tensor(
              rt_input.flat_values, ragged_rank=inner_rank_diff))
  else:
    rt_input = ragged_conversion_ops.from_tensor(
        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1)

  # Do broadcasting for any dimensions that will remain uniform.  We can do
  # these all at once, since they're independent of one another.
  multiples = [1] * dst_shape.rank
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
      src_size = src_shape.dimension_size(axis)
      dst_size = dst_shape.dimension_size(axis)
      # Tiling is only considered when the source size is 1 or not statically
      # known, and the destination size is not statically 1.  The actual
      # multiple is chosen by a `where` on whether src_size equals 1.
      if ((tensor_util.constant_value(src_size) in (1, None)) and
          (tensor_util.constant_value(dst_size) != 1)):
        multiples[axis] = array_ops.where(
            math_ops.equal(src_size, 1), dst_size, 1)
  # Skip the tile op entirely when every multiple is still the Python int 1.
  if not all(isinstance(v, int) and v == 1 for v in multiples):
    multiples = array_ops.stack(multiples, axis=0)
    rt_input = ragged_array_ops.tile(rt_input, multiples)

  # Reshape the flat values to [-1] + dst_shape.inner_dim_sizes.
  if broadcast_inner_dimensions:
    rt_input = rt_input.with_flat_values(
        array_ops.reshape(
            rt_input.flat_values,
            array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0)))

  # Do broadcasting for dimensions that become ragged.  We must do these from
  # outermost to innermost.
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
      dst_size = dst_shape.dimension_size(axis)
      rt_input = _ragged_tile_axis(rt_input, axis, dst_size)

  return rt_input
def _ragged_tile_axis(rt_input, axis, repeats):
  """Tiles one dimension of a RaggedTensor so that it matches a ragged shape.

  For `axis > 1` the function recurses into `rt_input.values` with
  `axis - 1`.  At `axis == 1` it rebuilds the tensor from `repeats` (used as
  the new outermost row lengths) plus the remaining row lengths and the flat
  values, each passed through `ragged_util.repeat_ranges`.

  Args:
    rt_input: The tensor to tile.  A non-ragged input is first converted with
      `ragged_conversion_ops.from_tensor(..., ragged_rank=1)`.
    axis: `int`, strictly positive.  The dimension to tile.
    repeats: Passed to `ragged_util.repeat_ranges` as the repeat counts and
      used as the row lengths of the tiled dimension.

  Returns:
    A `RaggedTensor`.
  """
  assert axis > 0  # Outermost dimension may not be ragged.

  if not ragged_tensor.is_ragged(rt_input):
    rt_input = ragged_conversion_ops.from_tensor(rt_input, ragged_rank=1)

  if axis > 1:
    # Not at the target dimension yet: tile inside the values instead.
    tiled_inner = _ragged_tile_axis(rt_input.values, axis - 1, repeats)
    return rt_input.with_values(tiled_inner)

  nested_splits = rt_input.nested_row_splits
  nested_lengths = rt_input.nested_row_lengths()
  range_splits = nested_splits[0]

  tiled_row_lengths = [repeats]
  # Walk the nested dimensions below the outermost one, carrying
  # `range_splits` down one nesting level after each step.
  for level_lengths, level_splits in zip(nested_lengths[1:],
                                         nested_splits[1:]):
    tiled_row_lengths.append(
        ragged_util.repeat_ranges(level_lengths, range_splits, repeats))
    range_splits = array_ops.gather(level_splits, range_splits)
  tiled_values = ragged_util.repeat_ranges(rt_input.flat_values, range_splits,
                                           repeats)
  return ragged_tensor.RaggedTensor.from_nested_row_lengths(
      tiled_values, tiled_row_lengths)