# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shapes & broadcasting for RaggedTensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util


class RaggedTensorDynamicShape(object):
  """A collection of tensors encoding the shape of a potentially ragged tensor.

  Each `RaggedTensorDynamicShape` consists of an ordered list of dimension
  sizes.  There are two dimension types:

    * "Uniform dimensions" are dimensions where all slices have the same
      length.  `RaggedTensorDynamicShape` records the size of each uniform
      dimension using a single scalar integer.

    * "Ragged dimensions" are dimensions whose slices may have different
      lengths.  `RaggedTensorDynamicShape` records the size of each ragged
      dimension using an integer vector containing the slice lengths for all
      the slices across that dimension.

  Furthermore, there are two ways a dimension might be encoded:

    * "Partitioned dimensions" are dimensions that are encoded using a
      `RaggedTensor`'s `nested_row_splits`.  The outermost partitioned
      dimension must be uniform, and the innermost partitioned dimension must
      be ragged.

    * "Inner dimensions" are dimensions that are encoded using a
      `RaggedTensor`'s `flat_values`.  Inner dimensions are always uniform.

  The dimension sizes are recorded using `partitioned_dim_sizes` and
  `inner_dim_sizes`:

    * `partitioned_dim_sizes` is a list of tensors (one for each partitioned
      dimension).

      * For uniform dimensions, the tensor is an integer scalar specifying the
        size of all slices across that dimension.
      * For ragged dimensions, the tensor is an integer vector specifying the
        size of each slice across that dimension.

    * `inner_dim_sizes` is a single integer vector, where each element
      specifies the size of a single inner dimension.

  Examples:

  Tensor                         | Ragged | Partitioned Dim Sizes  | Inner Dim
                                 : Rank   :                        : Sizes
  ------------------------------ | ------ | ---------------------- | ----------
  `[[1, 2, 3], [4, 5, 6]]`       |      0 |                        | `2, 3`
  `[[1, 2], [], [3, 4, 5]]`      |      1 | `3, (2, 0, 3)`         |
  `[[[1, 2], [3, 4]], [[5, 6]]]` |      1 | `2, (2, 1)`            | 2
  `[[[1, 2], [3]], [[4, 5]]]`    |      2 | `2, (2, 1), (2, 1, 2)` |
  """

  def __init__(self, partitioned_dim_sizes, inner_dim_sizes):
    """Creates a RaggedTensorDynamicShape.

    Args:
      partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for
        each partitioned dimension.  If dimension `d` is uniform, then
        `partitioned_dim_sizes[d]` must be an integer scalar, specifying the
        size of all slices across dimension `d`.  If dimension `d` is ragged,
        then `partitioned_dim_sizes[d]` must be an integer vector, specifying
        the size of each slice across dimension `d`.
      inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the
        number of inner dimensions.  `inner_dim_sizes[n]` is the size of all
        slices across the `n`th inner dimension (which is the
        `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor).

    Raises:
      ValueError: If the rank of some `partitioned_dim_sizes[d]` is unknown or
        greater than 1; if the outermost partitioned dimension is ragged; if
        the innermost partitioned dimension is uniform; or if
        `inner_dim_sizes` is not a vector.
    """
    assert isinstance(partitioned_dim_sizes, (list, tuple))
    with ops.name_scope(None, 'RaggedTensorDynamicShape',
                        (partitioned_dim_sizes, inner_dim_sizes)):
      # Every dimension size is converted to an int64 tensor.
      partitioned_dim_sizes = tuple(
          ragged_util.convert_to_int_tensor(
              size, dtype=dtypes.int64, name='partitioned_dimension_size')
          for size in partitioned_dim_sizes)
      inner_dim_sizes = ragged_util.convert_to_int_tensor(
          inner_dim_sizes, dtype=dtypes.int64, name='inner_dim_sizes')

      # Validate shapes.
      if partitioned_dim_sizes:
        for axis, dimension_size in enumerate(partitioned_dim_sizes):
          if dimension_size.shape.ndims is None:
            raise ValueError(
                'rank of partitioned_dim_sizes[%d] is unknown' % axis)
          dimension_size.shape.with_rank_at_most(1)
        if partitioned_dim_sizes[0].shape.ndims == 1:
          raise ValueError('outermost partitioned dimension must be uniform')
        if partitioned_dim_sizes[-1].shape.ndims == 0:
          raise ValueError('innermost partitioned dimension must be ragged')
      inner_dim_sizes.shape.assert_has_rank(1)

      self._partitioned_dim_sizes = partitioned_dim_sizes
      self._inner_dim_sizes = inner_dim_sizes

  def __repr__(self):
    return ('RaggedTensorDynamicShape'
            '(partitioned_dim_sizes=%r, inner_dim_sizes=%r)' %
            (self._partitioned_dim_sizes, self._inner_dim_sizes))

  @staticmethod
  def from_dim_sizes(dim_sizes):
    """Constructs a ragged shape from a list of dimension sizes.

    This list contains a single tensor for each dimension, where the tensor
    is a scalar if the dimension is uniform, or a vector if the dimension is
    ragged.

    Args:
      dim_sizes: List of int64 scalars or vectors.

    Returns:
      A RaggedTensorDynamicShape.

    Raises:
      ValueError: If some element of `dim_sizes` is neither a scalar nor a
        vector.
    """
    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes',
                        [dim_sizes]):
      dim_sizes = tuple(
          ragged_util.convert_to_int_tensor(
              size, dtype=dtypes.int64, name='dim_sizes') for size in dim_sizes)
      # Split the dimensions into partitioned & inner dimensions: everything
      # up to and including the last ragged (vector) dimension is partitioned;
      # the remaining dimensions are inner dimensions.
      inner_split = 0
      for dim, dim_size in enumerate(dim_sizes):
        if dim_size.shape.ndims == 1:
          inner_split = dim + 1
        elif dim_size.shape.ndims != 0:
          raise ValueError('Each dim_size must be a scalar or a vector')
      return RaggedTensorDynamicShape(dim_sizes[:inner_split],
                                      dim_sizes[inner_split:])

  @classmethod
  def from_tensor(cls, rt_input):
    """Constructs a ragged shape for a potentially ragged tensor.

    Args:
      rt_input: A `Tensor` or `RaggedTensor`, or a value convertible to one.

    Returns:
      A `RaggedTensorDynamicShape`.  For a non-ragged input, every dimension
      is an inner dimension.  For a `RaggedTensor`, the partitioned dimension
      sizes are `nrows()` followed by `nested_row_lengths()`, and the inner
      dimension sizes are the trailing dimensions of `flat_values`.
    """
    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]):
      rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
      if not ragged_tensor.is_ragged(rt_input):
        return cls([], array_ops.shape(rt_input))
      partitioned_dim_sizes = (
          (rt_input.nrows(),) + rt_input.nested_row_lengths())
      # Use `cls` in both branches so subclasses get instances of themselves.
      return cls(partitioned_dim_sizes,
                 array_ops.shape(rt_input.flat_values)[1:])

  def dimension_size(self, axis):
    """Returns the size of slices across the specified dimension.

    Args:
      axis: `int`.  The dimension whose size should be returned.

    Returns:
      A 0-D (uniform) or 1-D (ragged) int64 `Tensor`.

    Raises:
      TypeError: If `axis` is not an `int`.
    """
    if not isinstance(axis, int):
      raise TypeError('axis must be an integer')
    partitioned_ndims = len(self._partitioned_dim_sizes)
    if axis < partitioned_ndims:
      return self._partitioned_dim_sizes[axis]
    else:
      return self._inner_dim_sizes[axis - partitioned_ndims]

  def is_ragged(self, axis):
    """Returns true if the indicated dimension is ragged.

    Raises:
      TypeError: If `axis` is not an `int`.
      ValueError: If `axis` is negative, or is not less than a known rank.
    """
    if not isinstance(axis, int):
      raise TypeError('axis must be an integer')
    rank = self.rank
    if axis < 0:
      raise ValueError('Negative axis values are not supported')
    elif rank is not None and axis >= rank:
      raise ValueError('Expected axis=%s < rank=%s' % (axis, rank))
    else:
      # Only partitioned dimensions other than the outermost can be ragged;
      # they are ragged exactly when their size tensor is a vector.
      return (axis > 0 and axis < len(self._partitioned_dim_sizes) and
              self._partitioned_dim_sizes[axis].shape.ndims == 1)

  @property
  def rank(self):
    """The number of dimensions in this shape, or None if unknown."""
    inner_ndims = tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])
    if inner_ndims is None:
      return None
    else:
      return len(self._partitioned_dim_sizes) + inner_ndims

  @property
  def partitioned_dim_sizes(self):
    """The partitioned dimension sizes for this shape.

    Returns:
      A `list` of 0-D or 1-D integer `Tensor`.
    """
    return self._partitioned_dim_sizes

  @property
  def inner_dim_sizes(self):
    """The inner dimension sizes for this shape.

    Returns:
      A 1-D integer `Tensor`.
    """
    return self._inner_dim_sizes

  @property
  def num_partitioned_dimensions(self):
    """The number of partitioned dimensions in this shape."""
    return len(self._partitioned_dim_sizes)

  @property
  def num_inner_dimensions(self):
    """The number of inner dimensions, or `None` if not statically known."""
    return tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])

  def broadcast_to_rank(self, rank):
    """Adds leading size-1 dimensions to broadcast `self` to the given rank.

    E.g., if `shape1` is `[3, (D2), 4]`, then `shape1.broadcast_to_rank(5)`
    is `[1, 1, 3, (D2), 4]`.

    Args:
      rank: The rank for the returned shape.

    Returns:
      A RaggedTensorDynamicShape with `rank` dimensions, whose inner dimensions
      have the same size as `self` and whose outer dimensions have size `1`.

    Raises:
      ValueError: If `self.rank` is unknown or greater than `rank`.
    """
    if self.rank is None:
      raise ValueError('Unable to broadcast: self.rank is unknown')
    dims_to_add = rank - self.rank
    if dims_to_add < 0:
      raise ValueError('Unable to broadcast: rank=%d must be greater than '
                       'self.rank=%d.' % (rank, self.rank))
    elif dims_to_add == 0:
      return self
    elif self._partitioned_dim_sizes:
      # New size-1 dimensions are prepended as uniform partitioned dims.
      partitioned_dims = (1,) * dims_to_add + self._partitioned_dim_sizes
      return RaggedTensorDynamicShape(partitioned_dims, self._inner_dim_sizes)
    else:
      # Fully uniform shape: prepend the new size-1 dims to the inner dims.
      inner_dims = array_ops.concat(
          [array_ops.ones([dims_to_add], dtypes.int64), self.inner_dim_sizes],
          axis=0)
      return RaggedTensorDynamicShape([], inner_dims)

  def broadcast_dimension(self, axis, lengths):
    """Returns a shape that is broadcast-compatible with self & lengths.

    * If dimension[axis] is uniform and lengths is a scalar, then check
      that either lengths==1 or dimension_size(axis)==1 or
      lengths==dimension_size(axis); the new size of dimension[axis] is
      `lengths` if dimension_size(axis)==1, and dimension_size(axis)
      otherwise.

    * If dimension[axis] is uniform and lengths is a vector, then check
      that dimension_size(axis)==1, and raggedly tile dimension[axis] with
      lengths repeats.  (TODO: tiling could be skipped if we statically
      knew that every slice length is 1.)

    * If dimension[axis] is ragged and lengths is a scalar, then check
      that lengths==1.

    * If dimension[axis] is ragged and lengths is a vector, then check
      that self.dimension_size(axis) == lengths.

    The checks are emitted as a `control_flow_ops.Assert` op, which the
    returned shape's tensors take a control dependency on.

    Args:
      axis: `int`. The dimension to broadcast.
      lengths: 0-D or 1-D integer `Tensor`.

    Returns:
      A `RaggedTensorDynamicShape`.

    Raises:
      ValueError: If `lengths` has unknown rank or rank greater than 1, or if
        `axis` is 0 and `lengths` is a vector.
    """
    lengths = ragged_util.convert_to_int_tensor(
        lengths, name='lengths', dtype=dtypes.int64)
    # Check whether lengths is a scalar (for uniform dimensions) or
    # vector (for ragged dimensions).
    if lengths.shape.ndims is None:
      raise ValueError('lengths must have a known rank.')
    elif lengths.shape.ndims > 1:
      raise ValueError('lengths must be a scalar or vector')
    else:
      lengths_is_scalar = (lengths.shape.ndims == 0)

    # Verify that the shapes are compatible.
    if self.is_ragged(axis):
      if lengths_is_scalar:
        condition = math_ops.equal(lengths, 1)
      else:
        condition = math_ops.reduce_all(
            math_ops.equal(lengths, self.dimension_size(axis)))
    else:
      axis_dim_size = self.dimension_size(axis)
      if lengths_is_scalar:
        condition = (
            math_ops.equal(lengths, 1) | math_ops.equal(axis_dim_size, 1)
            | math_ops.equal(axis_dim_size, lengths))
      else:
        condition = math_ops.equal(axis_dim_size, 1)
    broadcast_err = [
        'Unable to broadcast: dimension size mismatch in dimension', axis,
        'lengths=', lengths, 'dim_size=',
        self.dimension_size(axis)
    ]
    broadcast_check = control_flow_ops.Assert(
        condition, data=broadcast_err, summarize=10)

    with ops.control_dependencies([broadcast_check]):
      # Partitioned dimensions:
      if axis < self.num_partitioned_dimensions:
        if self.is_ragged(axis):
          # Use an identity op to make sure the check actually gets run.
          return RaggedTensorDynamicShape(
              self._partitioned_dim_sizes,
              array_ops.identity(self.inner_dim_sizes))
        else:
          return self._broadcast_uniform_partitioned_dimension(axis, lengths)

      # Inner dimensions:
      else:
        if lengths_is_scalar:
          return self._broadcast_inner_dimension_to_uniform(axis, lengths)
        else:
          if axis == 0:
            raise ValueError('Unable to broadcast: '
                             'outermost dimension must be uniform.')
          return self._broadcast_inner_dimension_to_ragged(axis, lengths)

  def num_slices_in_dimension(self, axis):
    """Returns the total number of slices across the indicated dimension.

    For a negative `axis` this returns 1; otherwise it is the product of the
    (uniform) sizes, or the sum of the (ragged) lengths, down to `axis`.
    """
    if axis < 0:
      return constant_op.constant(1, dtype=dtypes.int64)
    elif self.is_ragged(axis):
      return math_ops.reduce_sum(self._partitioned_dim_sizes[axis])
    else:
      return self.dimension_size(axis) * self.num_slices_in_dimension(axis - 1)

  def _broadcast_uniform_partitioned_dimension(self, axis, lengths):
    """Broadcasts the partitioned dimension `axis` to match `lengths`.

    Dimension `axis` must be a uniform partitioned dimension.  The sizes of
    all partitioned dimensions after `axis` are tiled accordingly.
    """
    axis_dim_size = self.dimension_size(axis)
    partitioned_sizes = list(self._partitioned_dim_sizes[:axis])

    if lengths.shape.ndims == 0:
      # Uniform target: only a size-1 dimension is actually tiled.
      lengths = array_ops.where(
          math_ops.equal(axis_dim_size, 1), lengths, axis_dim_size)
      repeats = array_ops.where(math_ops.equal(axis_dim_size, 1), lengths, 1)
      splits = array_ops.stack([0, self.num_slices_in_dimension(axis)])
    else:
      # Ragged target: each slice is repeated by its corresponding length.
      splits = math_ops.range(
          array_ops.size(lengths, out_type=dtypes.int64) + 1)
      repeats = lengths

    partitioned_sizes.append(lengths)

    # Propagate the tiling through the remaining partitioned dimensions,
    # tracking `splits` as offsets into each successive dimension.
    for dim_size in self._partitioned_dim_sizes[axis + 1:]:
      if dim_size.shape.ndims == 0:
        partitioned_sizes.append(dim_size)
        splits *= dim_size
      else:
        partitioned_sizes.append(
            ragged_util.repeat_ranges(dim_size, splits, repeats))
        splits = array_ops.gather(
            ragged_util.lengths_to_splits(dim_size), splits)
    inner_sizes = self._inner_dim_sizes
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def _broadcast_inner_dimension_to_uniform(self, axis, length):
    """Broadcasts the inner dimension `axis` to match the scalar `length`.

    The new size is `length` if the current size is 1; otherwise the current
    size is kept.
    """
    dim_size = self.dimension_size(axis)
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = self._partitioned_dim_sizes
    inner_sizes = array_ops.concat([
        self._inner_dim_sizes[:axis_in_inner_dims],
        [array_ops.where(math_ops.equal(dim_size, 1), length, dim_size)],
        self._inner_dim_sizes[axis_in_inner_dims + 1:]
    ],
                                   axis=0)
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def _broadcast_inner_dimension_to_ragged(self, axis, lengths):
    """Makes the inner dimension `axis` ragged, with row lengths `lengths`.

    All inner dimensions before `axis` become (uniform) partitioned
    dimensions, `lengths` becomes the innermost partitioned dimension, and
    the inner dimensions after `axis` remain inner dimensions.
    """
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = (
        self._partitioned_dim_sizes + tuple([
            self._inner_dim_sizes[i] for i in range(axis_in_inner_dims)
        ]) + (lengths,))
    inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:]
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)


def broadcast_dynamic_shape(shape_x, shape_y):
  """Returns the shape formed by broadcasting two shapes to be compatible.

  Args:
    shape_x: A `RaggedTensorDynamicShape`
    shape_y: A `RaggedTensorDynamicShape`

  Returns:
    A `RaggedTensorDynamicShape`.

  Raises:
    TypeError: If `shape_x` or `shape_y` is not a `RaggedTensorDynamicShape`.
    ValueError: If `shape_x` and `shape_y` are not broadcast-compatible, or
      if either rank is unknown.
  """
  if not isinstance(shape_x, RaggedTensorDynamicShape):
    raise TypeError('shape_x must be a RaggedTensorDynamicShape')
  if not isinstance(shape_y, RaggedTensorDynamicShape):
    raise TypeError('shape_y must be a RaggedTensorDynamicShape')

  # Broadcast both shapes to have the same rank.
  if shape_x.rank is None or shape_y.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  broadcast_rank = max(shape_x.rank, shape_y.rank)
  shape_x = shape_x.broadcast_to_rank(broadcast_rank)
  shape_y = shape_y.broadcast_to_rank(broadcast_rank)

  # Broadcast dimensions one at a time, starting from the outermost dimension.
  for axis in range(broadcast_rank):
    shape_x = shape_x.broadcast_dimension(axis, shape_y.dimension_size(axis))
    shape_y = shape_y.broadcast_dimension(axis, shape_x.dimension_size(axis))

  return shape_x


def broadcast_to(rt_input, shape, broadcast_inner_dimensions=True):
  """Broadcasts a potentially ragged tensor to a ragged shape.

  Tiles `rt_input` as necessary to match the given shape.

  Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`.

  Args:
    rt_input: The potentially ragged tensor to broadcast.
    shape: A `RaggedTensorDynamicShape`
    broadcast_inner_dimensions: If false, then inner dimensions will not be
      tiled.

  Returns:
    A potentially ragged tensor whose values are taken from
    `rt_input`, and whose shape matches `shape`.

  Raises:
    TypeError: If `shape` is not a `RaggedTensorDynamicShape`.
    ValueError: If the rank or ragged rank of `rt_input` is incompatible with
      `shape`.
  """
  if not isinstance(shape, RaggedTensorDynamicShape):
    raise TypeError('shape must be a RaggedTensorDynamicShape')
  rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)

  # Broadcasting to a uniform shape.
  if shape.num_partitioned_dimensions == 0:
    return _broadcast_to_uniform_shape(rt_input, shape,
                                       broadcast_inner_dimensions)
  else:
    return _broadcast_to_ragged_shape(rt_input, shape,
                                      broadcast_inner_dimensions)


def _broadcast_to_uniform_shape(rt_input, shape, broadcast_inner_dimensions):
  """Broadcasts rt_input to the uniform shape `shape`."""
  if isinstance(rt_input, ragged_tensor.RaggedTensor):
    raise ValueError('Incompatible with shape: ragged rank mismatch')
  if broadcast_inner_dimensions:
    return array_ops.broadcast_to(rt_input, shape.inner_dim_sizes)
  else:
    return rt_input


def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
  """Broadcasts rt_input to the ragged shape `dst_shape`.

  Raises:
    ValueError: If either rank is unknown, if `rt_input` has a greater rank
      than `dst_shape`, or if `rt_input`'s ragged_rank is not less than the
      number of partitioned dimensions in `dst_shape`.
  """
  # dst_shape's rank must be >= rt_input's rank, and rt_input's ragged_rank
  # must be strictly less than dst_shape's number of partitioned dimensions.
  if rt_input.shape.ndims is None or dst_shape.rank is None:
    raise ValueError('Unable to broadcast: unknown rank')
  if rt_input.shape.ndims > dst_shape.rank:
    raise ValueError('Incompatible with shape: rank mismatch')
  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
    raise ValueError('Incompatible with shape: ragged rank mismatch')

  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)

  # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape.
  if dst_shape.rank > rt_input.shape.ndims:
    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
      rt_input = array_ops.reshape(
          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))
    for _ in range(dst_shape.rank - rt_input.shape.ndims):
      if ragged_tensor.is_ragged(rt_input):
        nrows = rt_input.nrows()
      else:
        nrows = array_ops.shape(rt_input, out_type=dtypes.int64)[0]
      # Wrap everything in a single new outer row.
      rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows])

  # Add ragged dimensions to match dst_shape.
  if ragged_tensor.is_ragged(rt_input):
    inner_rank_diff = (
        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
    if inner_rank_diff > 0:
      rt_input = rt_input.with_flat_values(
          ragged_conversion_ops.from_tensor(
              rt_input.flat_values, ragged_rank=inner_rank_diff))
  else:
    rt_input = ragged_conversion_ops.from_tensor(
        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1)

  # Do broadcasting for any dimensions that will remain uniform.  We can do
  # these all at once, since they're independent of one another.
  multiples = [1] * dst_shape.rank
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
      src_size = src_shape.dimension_size(axis)
      dst_size = dst_shape.dimension_size(axis)
      # Skip dimensions statically known not to need tiling.
      if ((tensor_util.constant_value(src_size) in (1, None)) and
          (tensor_util.constant_value(dst_size) != 1)):
        multiples[axis] = array_ops.where(
            math_ops.equal(src_size, 1), dst_size, 1)
  if not all(isinstance(v, int) and v == 1 for v in multiples):
    multiples = array_ops.stack(multiples, axis=0)
    rt_input = ragged_array_ops.tile(rt_input, multiples)

  if broadcast_inner_dimensions:
    # Broadcast (rather than reshape) flat_values so that any size-1 inner
    # dimension is actually tiled up to the size given by `dst_shape`.  The
    # leading 1 leaves the outermost dimension of flat_values unchanged.
    # (A reshape here would redistribute elements across rows instead of
    # tiling them.)
    new_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(rt_input.flat_values, out_type=dtypes.int64),
        array_ops.concat([[1], dst_shape.inner_dim_sizes], axis=0))
    rt_input = rt_input.with_flat_values(
        array_ops.broadcast_to(rt_input.flat_values, new_shape))

  # Do broadcasting for dimensions that become ragged.  We must do these from
  # outermost to innermost.
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
      dst_size = dst_shape.dimension_size(axis)
      rt_input = _ragged_tile_axis(rt_input, axis, dst_size)

  return rt_input


def _ragged_tile_axis(rt_input, axis, repeats):
  """Tile a dimension of a RaggedTensor to match a ragged shape.

  Args:
    rt_input: A `Tensor` or `RaggedTensor`; a `Tensor` is first converted to a
      `RaggedTensor` with `ragged_rank=1`.
    axis: `int`.  The dimension to tile; must be greater than 0.
    repeats: Integer vector giving the number of repeats for each slice.

  Returns:
    A `RaggedTensor`.
  """
  assert axis > 0  # Outermost dimension may not be ragged.

  if not ragged_tensor.is_ragged(rt_input):
    rt_input = ragged_conversion_ops.from_tensor(rt_input, ragged_rank=1)

  if axis > 1:
    # Recurse into the values until `axis` refers to the second dimension.
    return rt_input.with_values(
        _ragged_tile_axis(rt_input.values, axis - 1, repeats))
  else:
    src_row_splits = rt_input.nested_row_splits
    src_row_lengths = rt_input.nested_row_lengths()
    splits = src_row_splits[0]

    # Repeat each range of row lengths (and finally of flat values) according
    # to `repeats`, translating `splits` into each successive level.
    dst_row_lengths = [repeats]
    for i in range(1, len(src_row_lengths)):
      dst_row_lengths.append(
          ragged_util.repeat_ranges(src_row_lengths[i], splits, repeats))
      splits = array_ops.gather(src_row_splits[i], splits)
    dst_values = ragged_util.repeat_ranges(rt_input.flat_values, splits,
                                           repeats)
    return ragged_tensor.RaggedTensor.from_nested_row_lengths(
        dst_values, dst_row_lengths)