# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Shapes & broadcasting for RaggedTensors."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops.ragged import ragged_array_ops
30from tensorflow.python.ops.ragged import ragged_conversion_ops
31from tensorflow.python.ops.ragged import ragged_tensor
32from tensorflow.python.ops.ragged import ragged_util
33
34
35class RaggedTensorDynamicShape(object):
36  """A collection of tensors encoding the shape of a potentially ragged tensor.
37
38  Each `RaggedTensorDynamicShape` consists of an ordered list of dimension
39  sizes.  There are two dimension types:
40
41    * "Uniform dimensions" are dimenisons where all slices have the same
42      length.  `RaggedTensorDynamicShape` records the size of each uniform
43      dimension using a single scalar integer.
44
45    * "Ragged dimensions" are dimensions whose slices may have different
46      lengths.  `RaggedTensorDynamicShape` records the size of each ragged
47      dimension using an integer vector containing the slice lengths for all
48      the slices across that dimension.
49
50  Furthermore, there are two ways a dimension might be encoded:
51
52    * "Partitioned dimensions" are dimensions that are encoded using a
53      `RaggedTensor`'s `nested_row_splits`.  The outermostmost partitioned
54      dimension must be uniform, and the innermost partitioned dimension must
55      be ragged.
56
57    * "Inner dimensions" are dimensions that are encoded using a
58      `RaggedTensor`'s `flat_values`.  Inner dimensions are always uniform.
59
60  The sizes of partitioned dimensions are recorded using `partitioned_dim_sizes`
61  and `inner_dim_sizes`:
62
63    * `paritioned_dim_sizes` is a list of tensors (one for each partitioned
64      dimension).
65
66      * For uniform dimensions, the tensor is an integer scalar specifying the
67        size of all slices across that dimension.
68      * For ragged dimensions, the tensor is an integer vector specifying the
69        size of each slice across that dimension.
70
71    * `inner_dim_sizes` is a single integer vector, where each element
72      specifies the size of a single inner dimension.
73
74  Examples:
75
76  Tensor                         | Ragged | Partitioned Dim Sizes  | Inner Dim
77                                 : Rank   :                        : Sizes
78  ------------------------------ | ------ | ---------------------- | ----------
79  `[[1, 2, 3], [4, 5, 6]]`       |      0 |                        | `2, 3`
80  `[[1, 2], [], [3, 4, 5]]`      |      1 | `3, (2, 0, 3)`         |
81  `[[[1, 2], [3, 4]], [[5, 6]]]` |      1 | `2, (2, 1)`            | 2
82  `[[[1, 2], [3]], [[4, 5]]]`    |      2 | `2, (2, 1), (2, 1, 2)` |
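
  Illustrative example (a minimal sketch; assumes eager execution and the
  public `tf.ragged.constant` factory, which is not part of this module):

  ```python
  rt = tf.ragged.constant([[1, 2], [], [3, 4, 5]])
  shape = RaggedTensorDynamicShape.from_tensor(rt)
  shape.rank                 # 2
  shape.is_ragged(0)         # False: the outermost dimension is uniform.
  shape.is_ragged(1)         # True
  shape.dimension_size(0)    # scalar tensor: 3
  shape.dimension_size(1)    # vector tensor: [2, 0, 3]
  ```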
83  """
84
85  def __init__(self, partitioned_dim_sizes, inner_dim_sizes):
86    """Creates a RaggedTensorDynamicShape.
87
88    Args:
89      partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for
90        each partitioned dimension.  If dimension `d` is uniform, then
91        `partitioned_dim_sizes[d]` must be an integer scalar, specifying the
92        size of all slices across dimension `d`.  If dimension `d` is ragged,
93        then `partitioned_dim_sizes[d]` must be an integer vector, specifying
94        the size of each slice across dimension `d`.
95      inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the
96        number of inner dimensions.  `inner_dim_sizes[n]` is the size of all
97        slices across the `n`th inner dimension (which is the
98        `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor.
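
    Illustrative example (a minimal sketch; the arguments describe the first
    and third example tensors from the class docstring):

    ```python
    # Shape of the uniform tensor [[1, 2, 3], [4, 5, 6]]:
    RaggedTensorDynamicShape([], [2, 3])
    # Shape of the ragged tensor [[[1, 2], [3, 4]], [[5, 6]]]:
    RaggedTensorDynamicShape([2, [2, 1]], [2])
    ```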
99    """
100    assert isinstance(partitioned_dim_sizes, (list, tuple))
101    with ops.name_scope(None, 'RaggedTensorDynamicShape',
102                        (partitioned_dim_sizes, inner_dim_sizes)):
103      partitioned_dim_sizes = tuple(
104          ragged_util.convert_to_int_tensor(
105              size, dtype=dtypes.int64, name='partitioned_dimension_size')
106          for size in partitioned_dim_sizes)
107      inner_dim_sizes = ragged_util.convert_to_int_tensor(
108          inner_dim_sizes, dtype=dtypes.int64, name='inner_dim_sizes')
109
110      # Validate shapes.
111      if partitioned_dim_sizes:
112        for axis, dimension_size in enumerate(partitioned_dim_sizes):
113          if dimension_size.shape.ndims is None:
114            raise ValueError(
115                'rank of partitioned_dim_sizes[%d] is unknown' % axis)
116          dimension_size.shape.with_rank_at_most(1)
117        if partitioned_dim_sizes[0].shape.ndims == 1:
118          raise ValueError('outermost partitioned dimension must be uniform')
119        if partitioned_dim_sizes[-1].shape.ndims == 0:
120          raise ValueError('innermost partitioned dimension must be ragged')
121      inner_dim_sizes.shape.assert_has_rank(1)
122
123      self._partitioned_dim_sizes = partitioned_dim_sizes
124      self._inner_dim_sizes = inner_dim_sizes
125
126  def __repr__(self):
127    return ('RaggedTensorDynamicShape'
128            '(partitioned_dim_sizes=%r, inner_dim_sizes=%r)' %
129            (self._partitioned_dim_sizes, self._inner_dim_sizes))
130
131  @staticmethod
132  def from_dim_sizes(dim_sizes):
133    """Constructs a ragged shape from a list of dimension sizes.
134
135    This list contains a single tensor for each dimension, where the tensor
136    is a scalar if the dimension is uniform, or a vector if the dimension is
137    ragged.
138
139    Args:
140      dim_sizes: List of int64 scalars or vectors.
141
142    Returns:
143      A RaggedTensorDynamicShape.
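
    Illustrative example (a minimal sketch; the dimension sizes describe the
    tensor `[[[1, 2], [3, 4]], [[5, 6]]]` from the class docstring):

    ```python
    shape = RaggedTensorDynamicShape.from_dim_sizes([2, [2, 1], 2])
    # shape.partitioned_dim_sizes covers dimensions 0 and 1: (2, [2, 1])
    # shape.inner_dim_sizes covers dimension 2: [2]
    ```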
144    """
145    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes',
146                        [dim_sizes]):
147      dim_sizes = tuple(
148          ragged_util.convert_to_int_tensor(
149              size, dtype=dtypes.int64, name='dim_sizes') for size in dim_sizes)
150      # Split the dimensions into partitioned & inner dimensions.
151      inner_split = 0
152      for dim, dim_size in enumerate(dim_sizes):
153        if dim_size.shape.ndims == 1:
154          inner_split = dim + 1
155        elif dim_size.shape.ndims != 0:
156          raise ValueError('Each dim_size must be a scalar or a vector')
157      return RaggedTensorDynamicShape(dim_sizes[:inner_split],
158                                      dim_sizes[inner_split:])
159
160  @classmethod
161  def from_tensor(cls, rt_input):
162    """Constructs a ragged shape for a potentially ragged tensor."""
163    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]):
164      rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
165      if not ragged_tensor.is_ragged(rt_input):
166        return cls([], array_ops.shape(rt_input))
167      else:
168        partitioned_dim_sizes = (
169            (rt_input.nrows(),) + rt_input.nested_row_lengths())
170        return RaggedTensorDynamicShape(
171            partitioned_dim_sizes,
172            array_ops.shape(rt_input.flat_values)[1:])
173
174  def dimension_size(self, axis):
175    """Returns the size of slices across the specified dimension."""
176    if not isinstance(axis, int):
177      raise TypeError('axis must be an integer')
178    partitioned_ndims = len(self._partitioned_dim_sizes)
179    if axis < partitioned_ndims:
180      return self._partitioned_dim_sizes[axis]
181    else:
182      return self._inner_dim_sizes[axis - partitioned_ndims]
183
184  def is_ragged(self, axis):
185    """Returns true if the indicated dimension is ragged."""
186    if not isinstance(axis, int):
187      raise TypeError('axis must be an integer')
188    rank = self.rank
189    if axis < 0:
190      raise ValueError('Negative axis values are not supported')
191    elif rank is not None and axis >= rank:
192      raise ValueError('Expected axis=%s < rank=%s' % (axis, rank))
193    else:
194      return (axis > 0 and axis < len(self._partitioned_dim_sizes) and
195              self._partitioned_dim_sizes[axis].shape.ndims == 1)
196
197  @property
198  def rank(self):
199    """The number of dimensions in this shape, or None if unknown."""
200    inner_ndims = tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])
201    if inner_ndims is None:
202      return None
203    else:
204      return len(self._partitioned_dim_sizes) + inner_ndims
205
206  @property
207  def partitioned_dim_sizes(self):
208    """The partitioned dimension sizes for this shape.
209
210    Returns:
211      A `list` of 0-D or 1-D integer `Tensor`.
212    """
213    return self._partitioned_dim_sizes
214
215  @property
216  def inner_dim_sizes(self):
217    """The inner dimension sizes for this shape.
218
219    Returns:
220      A 1-D integer `Tensor`.
221    """
222    return self._inner_dim_sizes
223
224  @property
225  def num_partitioned_dimensions(self):
226    """The number of partitioned dimensions in this shape."""
227    return len(self._partitioned_dim_sizes)
228
229  @property
230  def num_inner_dimensions(self):
231    """The number of inner dimensions, or `None` if not statically known."""
232    return tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])
233
234  def broadcast_to_rank(self, rank):
235    """Adds leading size-1 dimensions to broadcast `self` to the given rank.
236
237    E.g., if `shape1` is `[3, (D2), 4]`, then `shape1.broadcast_to_rank(5)`
238    is `[1, 1, 3, (D2), 4]`.
239
240    Args:
241      rank: The rank for the returned shape.
242
243    Returns:
244      A RaggedTensorDynamicShape with `rank` dimensions, whose inner dimensions
245      have the same size as `self` and whose outer dimensions have size `1`.
246
247    Raises:
248      ValueError: If `self.rank` is unknown or greater than `rank`.
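
    Illustrative example (a minimal sketch, using concrete row lengths for
    the ragged dimension `(D2)` above):

    ```python
    shape1 = RaggedTensorDynamicShape.from_dim_sizes([3, [2, 0, 3], 4])
    shape2 = shape1.broadcast_to_rank(5)   # shape2.rank == 5
    # The two new leading dimensions of shape2 both have size 1.
    ```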
249    """
250    if self.rank is None:
251      raise ValueError('Unable to broadcast: self.rank is unknown')
252    dims_to_add = rank - self.rank
253    if dims_to_add < 0:
254      raise ValueError('Unable to broadcast: rank=%d must be greater than '
255                       'self.rank=%d.' % (rank, self.rank))
256    elif dims_to_add == 0:
257      return self
258    elif self._partitioned_dim_sizes:
259      partitioned_dims = (1,) * dims_to_add + self._partitioned_dim_sizes
260      return RaggedTensorDynamicShape(partitioned_dims, self._inner_dim_sizes)
261    else:
262      inner_dims = array_ops.concat(
263          [array_ops.ones([dims_to_add], dtypes.int64), self.inner_dim_sizes],
264          axis=0)
265      return RaggedTensorDynamicShape([], inner_dims)
266
267  def broadcast_dimension(self, axis, lengths):
268    """Returns a shape that is broadcast-compatible with self & lengths.
269
270    * If dimension[axis] is uniform and lengths is a scalar, the check
271      that either lengths==1 or axis==1 or lengths==axis, and tile
272      dimension[axis] with tf.where(lengths==axis, 1, axis) repeats.
273
274    * If dimension[axis] is uniform and lengths is a vector, then check
275      that dimension[axis]==1, and raggedly tile dimension[axis] with
276      lengths repeats.  (we can skip tiling if we statically know that
277      slice_lengths == 1??)
278
279    * If dimension[axis] is ragged and lengths is a scalar, then check
280      that lengths==1.
281
282    * If dimension[axis] is ragged and lengths is a vector, then check
283      that self.dimension_size(axis) == lengths.
284
285    Args:
286      axis: `int`.  The dimension to broadcast.
287      lengths: 0-D or 1-D integer `Tensor`.
288
289    Returns:
290      A `RaggedTensorDynamicShape`.
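
    Illustrative example (a minimal sketch; broadcasts the uniform second
    dimension of a `3 x 1` shape against ragged row lengths):

    ```python
    shape = RaggedTensorDynamicShape.from_dim_sizes([3, 1])
    shape.broadcast_dimension(1, [2, 0, 3])  # the ragged shape [3, (2, 0, 3)]
    ```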
291    """
292    lengths = ragged_util.convert_to_int_tensor(
293        lengths, name='lengths', dtype=dtypes.int64)
294    # Check whether lengths is a scalar (for uniform dimensions) or
295    # vector (for ragged dimensions).
296    if lengths.shape.ndims is None:
297      raise ValueError('lengths must have a known rank.')
298    elif lengths.shape.ndims > 1:
299      raise ValueError('lengths must be a scalar or vector')
300    else:
301      lengths_is_scalar = (lengths.shape.ndims == 0)
302
303    # Verify that the shapes are compatible.
304    if self.is_ragged(axis):
305      if lengths_is_scalar:
306        condition = math_ops.equal(lengths, 1)
307      else:
308        condition = math_ops.reduce_all(
309            math_ops.equal(lengths, self.dimension_size(axis)))
310    else:
311      axis_dim_size = self.dimension_size(axis)
312      if lengths_is_scalar:
313        condition = (
314            math_ops.equal(lengths, 1) | math_ops.equal(axis_dim_size, 1)
315            | math_ops.equal(axis_dim_size, lengths))
316      else:
317        condition = math_ops.equal(axis_dim_size, 1)
318    broadcast_err = [
319        'Unable to broadcast: dimension size mismatch in dimension', axis,
320        'lengths=', lengths, 'dim_size=',
321        self.dimension_size(axis)
322    ]
323    broadcast_check = control_flow_ops.Assert(
324        condition, data=broadcast_err, summarize=10)
325
326    with ops.control_dependencies([broadcast_check]):
327      # Partitioned dimensions:
328      if axis < self.num_partitioned_dimensions:
329        if self.is_ragged(axis):
330          # Use an identity op to make sure the check actually gets run.
331          return RaggedTensorDynamicShape(
332              self._partitioned_dim_sizes,
333              array_ops.identity(self.inner_dim_sizes))
334        else:
335          return self._broadcast_uniform_partitioned_dimension(axis, lengths)
336
337      # Inner dimensions:
338      else:
339        if lengths_is_scalar:
340          return self._broadcast_inner_dimension_to_uniform(axis, lengths)
341        else:
342          if axis == 0:
343            raise ValueError('Unable to broadcast: '
344                             'outermost dimension must be uniform.')
345          return self._broadcast_inner_dimension_to_ragged(axis, lengths)
346
347  def num_slices_in_dimension(self, axis):
348    """Returns the total number of slices across the indicated dimension."""
    if axis < 0:
      return constant_op.constant(1, dtype=dtypes.int64)
    elif self.is_ragged(axis):
      return math_ops.reduce_sum(self._partitioned_dim_sizes[axis])
    else:
      return self.dimension_size(axis) * self.num_slices_in_dimension(axis - 1)

  def _broadcast_uniform_partitioned_dimension(self, axis, lengths):
    """Broadcasts the partitioned dimension `axis` to match `lengths`."""
    axis_dim_size = self.dimension_size(axis)
    partitioned_sizes = list(self._partitioned_dim_sizes[:axis])

    if lengths.shape.ndims == 0:
      lengths = array_ops.where(
          math_ops.equal(axis_dim_size, 1), lengths, axis_dim_size)
      repeats = array_ops.where(math_ops.equal(axis_dim_size, 1), lengths, 1)
      splits = array_ops.stack([0, self.num_slices_in_dimension(axis)])
    else:
      splits = math_ops.range(
          array_ops.size(lengths, out_type=dtypes.int64) + 1)
      repeats = lengths

    partitioned_sizes.append(lengths)

    for dim_size in self._partitioned_dim_sizes[axis + 1:]:
      if dim_size.shape.ndims == 0:
        partitioned_sizes.append(dim_size)
        splits *= dim_size
      else:
        partitioned_sizes.append(
            ragged_util.repeat_ranges(dim_size, splits, repeats))
        splits = array_ops.gather(
            ragged_util.lengths_to_splits(dim_size), splits)
    inner_sizes = self._inner_dim_sizes
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def _broadcast_inner_dimension_to_uniform(self, axis, length):
    """Broadcasts the inner dimension `axis` to match `length`."""
    dim_size = self.dimension_size(axis)
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = self._partitioned_dim_sizes
    inner_sizes = array_ops.concat([
        self._inner_dim_sizes[:axis_in_inner_dims],
        [array_ops.where(math_ops.equal(dim_size, 1), length, dim_size)],
        self._inner_dim_sizes[axis_in_inner_dims + 1:]
    ], axis=0)
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def _broadcast_inner_dimension_to_ragged(self, axis, lengths):
    """Broadcasts the inner dimension `axis` to match the ragged `lengths`."""
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = (
        self._partitioned_dim_sizes + tuple([
            self._inner_dim_sizes[i] for i in range(axis_in_inner_dims)
        ]) + (lengths,))
    inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:]
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)


def broadcast_dynamic_shape(shape_x, shape_y):
  """Returns the shape formed by broadcasting two shapes to be compatible.

  Args:
    shape_x: A `RaggedTensorDynamicShape`
    shape_y: A `RaggedTensorDynamicShape`

  Returns:
    A `RaggedTensorDynamicShape`.

  Raises:
    ValueError: If `shape_x` and `shape_y` are not broadcast-compatible.
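
  Illustrative example (a minimal sketch; assumes the public
  `tf.ragged.constant` factory, which is not part of this module):

  ```python
  x_shape = RaggedTensorDynamicShape.from_dim_sizes([3, 1])
  y_shape = RaggedTensorDynamicShape.from_tensor(
      tf.ragged.constant([[1, 2], [], [3, 4, 5]]))
  broadcast_dynamic_shape(x_shape, y_shape)  # the ragged shape [3, (2, 0, 3)]
  ```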
419  """
420  if not isinstance(shape_x, RaggedTensorDynamicShape):
421    raise TypeError('shape_x must be a RaggedTensorDynamicShape')
422  if not isinstance(shape_y, RaggedTensorDynamicShape):
423    raise TypeError('shape_y must be a RaggedTensorDynamicShape')
424
425  # Broadcast both shapes to have the same rank.
426  if shape_x.rank is None or shape_y.rank is None:
427    raise ValueError('Unable to broadcast: unknown rank')
428  broadcast_rank = max(shape_x.rank, shape_y.rank)
429  shape_x = shape_x.broadcast_to_rank(broadcast_rank)
430  shape_y = shape_y.broadcast_to_rank(broadcast_rank)
431
432  # Broadcast dimensions one at a time, starting from the outermost dimension.
433  for axis in range(broadcast_rank):
434    shape_x = shape_x.broadcast_dimension(axis, shape_y.dimension_size(axis))
435    shape_y = shape_y.broadcast_dimension(axis, shape_x.dimension_size(axis))
436
437  return shape_x
438
439
440def broadcast_to(rt_input, shape, broadcast_inner_dimensions=True):
441  """Broadcasts a potentially ragged tensor to a ragged shape.
442
443  Tiles `rt_input` as necessary to match the given shape.
444
445  Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`.
446
447  Args:
448    rt_input: The potentially ragged tensor to broadcast.
449    shape: A `RaggedTensorDynamicShape`
450    broadcast_inner_dimensions: If false, then inner dimensions will not be
451      tiled.
452
453  Returns:
454    A potentially ragged tensor whose values are taken from
455    `rt_input`, and whose shape matches `shape`.
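
  Illustrative example (a minimal sketch; assumes eager execution and the
  public `tf.constant` / `tf.ragged.constant` factories, which are not part
  of this module):

  ```python
  dense = tf.constant([[10], [20], [30]])               # shape [3, 1]
  ragged = tf.ragged.constant([[1, 2], [], [3, 4, 5]])  # shape [3, (2, 0, 3)]
  dst_shape = RaggedTensorDynamicShape.from_tensor(ragged)
  broadcast_to(dense, dst_shape)  # roughly [[10, 10], [], [30, 30, 30]]
  ```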
456  """
457  if not isinstance(shape, RaggedTensorDynamicShape):
458    raise TypeError('shape must be a RaggedTensorDynamicShape')
459  rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
460
461  # Broadcasting to a uniform shape.
462  if shape.num_partitioned_dimensions == 0:
463    return _broadcast_to_uniform_shape(rt_input, shape,
464                                       broadcast_inner_dimensions)
465  else:
466    return _broadcast_to_ragged_shape(rt_input, shape,
467                                      broadcast_inner_dimensions)
468
469
470def _broadcast_to_uniform_shape(rt_input, shape, broadcast_inner_dimensions):
471  """Broadcasts rt_input to the uniform shape `shape`."""
472  if isinstance(rt_input, ragged_tensor.RaggedTensor):
473    raise ValueError('Incompatible with shape: ragged rank mismatch')
474  if broadcast_inner_dimensions:
475    return array_ops.broadcast_to(rt_input, shape.inner_dim_sizes)
476  else:
477    return rt_input
478
479
480def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
481  """Broadcasts rt_input to the ragged shape `dst_shape`."""
482  # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's
483  if rt_input.shape.ndims is None or dst_shape.rank is None:
484    raise ValueError('Unable to broadcast: unknown rank')
485  if rt_input.shape.ndims > dst_shape.rank:
486    raise ValueError('Incompatible with shape: rank mismatch')
487  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
488      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
489    raise ValueError('Incompatible with shape: ragged rank mismatch')
490
491  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
492  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)
493
494  # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape.
495  if dst_shape.rank > rt_input.shape.ndims:
496    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
497      rt_input = array_ops.reshape(
498          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))
499    for _ in range(dst_shape.rank - rt_input.shape.ndims):
500      if ragged_tensor.is_ragged(rt_input):
501        nrows = rt_input.nrows()
502      else:
503        nrows = array_ops.shape(rt_input, out_type=dtypes.int64)[0]
504      rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows])
505
506  # Add ragged dimensions to match dst_shape.
507  if ragged_tensor.is_ragged(rt_input):
508    inner_rank_diff = (
509        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
510    if inner_rank_diff > 0:
511      rt_input = rt_input.with_flat_values(
512          ragged_conversion_ops.from_tensor(
513              rt_input.flat_values, ragged_rank=inner_rank_diff))
514  else:
515    rt_input = ragged_conversion_ops.from_tensor(
516        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1)
517
518  # Do broadcasting for any dimensions that will remain uniform.  We can do
519  # these all at once, since they're independent of one another.
520  multiples = [1] * dst_shape.rank
521  for axis in range(dst_shape.num_partitioned_dimensions):
522    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
523      src_size = src_shape.dimension_size(axis)
524      dst_size = dst_shape.dimension_size(axis)
525      if ((tensor_util.constant_value(src_size) in (1, None)) and
526          (tensor_util.constant_value(dst_size) != 1)):
527        multiples[axis] = array_ops.where(
528            math_ops.equal(src_size, 1), dst_size, 1)
529  if not all(isinstance(v, int) and v == 1 for v in multiples):
530    multiples = array_ops.stack(multiples, axis=0)
531    rt_input = ragged_array_ops.tile(rt_input, multiples)
532
533  if broadcast_inner_dimensions:
534    rt_input = rt_input.with_flat_values(
535        array_ops.reshape(
536            rt_input.flat_values,
537            array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0)))
538
539  # Do broadcasting for dimensions that become ragged.  We must do these from
540  # outermost to innermost.
541  for axis in range(dst_shape.num_partitioned_dimensions):
542    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
543      dst_size = dst_shape.dimension_size(axis)
544      rt_input = _ragged_tile_axis(rt_input, axis, dst_size)
545
546  return rt_input
547
548
549def _ragged_tile_axis(rt_input, axis, repeats):
550  """Tile a dimension of a RaggedTensor to match a ragged shape."""
551  assert axis > 0  # Outermost dimension may not be ragged.
552
553  if not ragged_tensor.is_ragged(rt_input):
554    rt_input = ragged_conversion_ops.from_tensor(rt_input, ragged_rank=1)
555
556  if axis > 1:
557    return rt_input.with_values(
558        _ragged_tile_axis(rt_input.values, axis - 1, repeats))
559  else:
560    src_row_splits = rt_input.nested_row_splits
561    src_row_lengths = rt_input.nested_row_lengths()
562    splits = src_row_splits[0]
563
564    dst_row_lengths = [repeats]
565    for i in range(1, len(src_row_lengths)):
566      dst_row_lengths.append(
567          ragged_util.repeat_ranges(src_row_lengths[i], splits, repeats))
568      splits = array_ops.gather(src_row_splits[i], splits)
569    dst_values = ragged_util.repeat_ranges(rt_input.flat_values, splits,
570                                           repeats)
571    return ragged_tensor.RaggedTensor.from_nested_row_lengths(
572        dst_values, dst_row_lengths)
573