1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Support for ragged tensors."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import check_ops
28from tensorflow.python.ops import gen_ragged_math_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops.ragged import ragged_functional_ops
31from tensorflow.python.ops.ragged import ragged_tensor
32from tensorflow.python.ops.ragged import segment_id_ops
33from tensorflow.python.util import dispatch
34from tensorflow.python.util.tf_export import tf_export
35
36
37#===============================================================================
38# ragged.range
39#===============================================================================
40# pylint: disable=redefined-builtin
@tf_export('ragged.range')
@dispatch.add_dispatch_support
def range(starts,
          limits=None,
          deltas=1,
          dtype=None,
          name=None,
          row_splits_dtype=dtypes.int64):
  """Returns a `RaggedTensor` containing the specified sequences of numbers.

  Each row of the returned `RaggedTensor` contains a single sequence:

  ```python
  ragged.range(starts, limits, deltas)[i] ==
      tf.range(starts[i], limits[i], deltas[i])
  ```

  If `starts[i] >= limits[i] and deltas[i] > 0`, then `output[i]` will be an
  empty list.  Similarly, if `starts[i] <= limits[i] and deltas[i] < 0`, then
  `output[i]` will be an empty list.  This behavior is consistent with the
  Python `range` function, but differs from the `tf.range` op, which returns
  an error when the range runs opposite to the sign of the delta (e.g.
  `starts[i] > limits[i]` with `deltas[i] > 0`).

  Examples:

  >>> tf.ragged.range([3, 5, 2]).to_list()
  [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12]).to_list()
  [[0, 1, 2], [], [8, 9, 10, 11]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12], 2).to_list()
  [[0, 2], [], [8, 10]]

  The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors.
  The vector inputs must all have the same size.  Scalar inputs are broadcast
  to match the size of the vector inputs.

  Args:
    starts: Vector or scalar `Tensor`.  Specifies the first entry for each range
      if `limits` is not `None`; otherwise, specifies the range limits, and the
      first entries default to `0`.
    limits: Vector or scalar `Tensor`.  Specifies the exclusive upper limits for
      each range.
    deltas: Vector or scalar `Tensor`.  Specifies the increment for each range.
      Defaults to `1`.
    dtype: The type of the elements of the resulting tensor.  If not specified,
      then a value is chosen based on the other args.
    name: A name for the operation.
    row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
      tensor.  One of `tf.int32` or `tf.int64`.

  Returns:
    A `RaggedTensor` of type `dtype` with `ragged_rank=1`.
  """
  row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
  # Single-argument form mirrors Python's `range(stop)`: starts default to 0.
  if limits is None:
    starts, limits = 0, starts

  with ops.name_scope(name, 'RaggedRange', [starts, limits, deltas]) as name:
    starts = ops.convert_to_tensor(starts, dtype=dtype, name='starts')
    limits = ops.convert_to_tensor(limits, dtype=dtype, name='limits')
    deltas = ops.convert_to_tensor(deltas, dtype=dtype, name='deltas')

    # infer dtype if not explicitly provided
    if dtype is None:
      starts, limits, deltas = _infer_matching_dtype(
          [starts, limits, deltas],
          [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])

    result = gen_ragged_math_ops.ragged_range(
        starts, limits, deltas, Tsplits=row_splits_dtype, name=name)
    # The kernel produces consistent splits, so skip runtime validation.
    return ragged_tensor.RaggedTensor.from_row_splits(
        result.rt_dense_values, result.rt_nested_splits, validate=False)
113
114
def _infer_matching_dtype(tensors, dtype_hierarchy):
  """Infers a matching dtype for tensors, and casts them to that dtype.

  The inferred dtype is whichever of the tensors' dtypes appears latest in
  `dtype_hierarchy`, so later entries take precedence over earlier ones.

  Args:
    tensors: A list of tensors (objects with a `dtype` attribute).
    dtype_hierarchy: A list of dtypes ordered from lowest to highest
      precedence.  Every tensor's dtype must appear in this list.

  Returns:
    A list containing each element of `tensors` cast to the inferred dtype.

  Raises:
    TypeError: If any tensor's dtype is not in `dtype_hierarchy`.
  """
  unsupported = [t.dtype for t in tensors if t.dtype not in dtype_hierarchy]
  if unsupported:
    # Explicit check instead of `assert` so it survives `python -O`; without
    # it, `dtype_hierarchy.index` below fails with an obscure "not in list".
    raise TypeError(
        'Expected all dtypes to be one of %s, but got %s.' %
        ([str(d) for d in dtype_hierarchy], [str(d) for d in unsupported]))
  inferred_dtype = max([t.dtype for t in tensors], key=dtype_hierarchy.index)
  return [math_ops.cast(t, inferred_dtype) for t in tensors]
120
121
# Declare the RaggedRange op as non-differentiable, so gradient computation
# treats its outputs as constants instead of looking up a gradient function.
ops.no_gradient('RaggedRange')
123
124#===============================================================================
125# ragged_segment_<AGGREGATE>
126#===============================================================================
127
128# Docstring template used for the raggged_segment_<AGGREGATE> ops.
129_RAGGED_SEGMENT_DOCSTRING = """\
130Computes the %(combination)s along segments of a RaggedTensor.
131
132  Returns a RaggedTensor `output` with `num_segments` rows, where the row
133  `output[i]` is formed by taking the %(combination)s of all rows of `data`
134  whose corresponding `segment_id` is `i`.
135
136  The length of the row `output[i]` will be the maximum of the lengths of
137  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
138  rows correspond to a given segment ID, then the output row for that segment
139  ID will be empty.
140
141  Args:
142    data: A `RaggedTensor` containing the values to combine.
143    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
144      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
145      Must be greater than or equal to zero, and less than `num_segments`.
146      `segment_ids` is not required to be sorted.
147    num_segments: An `int32` or `int64` scalar specifying the number of
148      distinct segment ids.
149    name: A name prefix for the returned tensor (optional).
150  Returns:
151    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
152    has the same dtype as `data`, and its shape is
153    `[num_segments] + data.shape[segment_ids.rank:]`.
154  Raises:
155    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
156"""
157
158
def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              separator=None,
                              name=None):
  """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by combining all rows of `data` whose corresponding
  `segment_id` is `i`.  The values in each row are combined using
  `unsorted_segment_op`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  If neither `data` nor `segment_ids` is ragged, the call is delegated
  directly to `unsorted_segment_op` and its result is returned unchanged.
  Otherwise the aggregation recurses one ragged dimension at a time.

  Args:
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in each row.  Must have the same signature and basic behavior as
      `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    data: A `RaggedTensor` containing the values to be combined.
    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar.
    separator: An optional string. Defaults to None. The separator to use when
      joining. Only used for string types.  When not None it is passed to
      `unsorted_segment_op` as an extra positional argument after
      `num_segments`, so the op must accept it.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the aggregated values.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If segment_ids.shape is not a prefix of data.shape (detected
      statically when `segment_ids` is ragged but `data` is not; a mismatch in
      row partitions is checked by a `check_ops.assert_equal` op), or if
      `segment_ids` does not have dtype int32 or int64 (checked only when
      `data` or `segment_ids` is ragged).
  """
  # Fast path: nothing is ragged, so the dense segment op handles it directly.
  if not (ragged_tensor.is_ragged(data) or
          ragged_tensor.is_ragged(segment_ids)):
    if separator is not None:
      # It uses unsorted_segment_join.
      return unsorted_segment_op(data, segment_ids, num_segments, separator,
                                 name)
    else:
      return unsorted_segment_op(data, segment_ids, num_segments, name)

  with ops.name_scope(name, 'RaggedSegment',
                      [data, segment_ids, num_segments]) as name:
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        segment_ids, name='segment_ids')
    data, segment_ids = ragged_tensor.match_row_splits_dtypes(data, segment_ids)
    if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('segment_ids must have dtype int32 or int64.')

    if ragged_tensor.is_ragged(segment_ids):
      if not ragged_tensor.is_ragged(data):
        raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                         'but segment_ids is ragged and data is not.')
      # When both are ragged with identical outer row partitions, aggregating
      # their `values` element-wise is equivalent; the assert guards that the
      # partitions really match before the recursive call runs.
      check_splits = check_ops.assert_equal(
          segment_ids.row_splits,
          data.row_splits,
          message='segment_ids.shape must be a prefix of data.shape')
      with ops.control_dependencies([check_splits]):
        return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                         segment_ids.values, num_segments,
                                         separator)

    # From here on `segment_ids` is dense, so (given the fast path above)
    # `data` is ragged.

    # Find the length of each row in data.  (shape=[data_nrows])
    data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

    # Find the length that each output row will have.  The length of the row
    # corresponding to segment `id` is `max(data_row_lengths[i])` where
    # `segment_ids[i]=id`.  (shape=[output_nrows])
    # Segments that no row maps to get a sentinel value below zero from
    # unsorted_segment_max; clamping at 0 makes those output rows empty.
    output_row_lengths = math_ops.maximum(
        math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
                                      num_segments), 0)

    # Build the splits tensor for the output RaggedTensor.
    output_splits = array_ops.concat([
        array_ops.zeros([1], output_row_lengths.dtype),
        math_ops.cumsum(output_row_lengths)
    ],
                                     axis=0)

    # For each row in `data`, find the start & limit position where that row's
    # values will be aggregated in output.values.
    data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids)
    data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths

    # For each value in `data.values`, find the position where it will
    # aggregated in `output.values`.
    # Get the target output values index for each data values index.
    # NOTE: `range` is this module's ragged `range` (it shadows the builtin);
    # `.values` flattens the per-row index ranges into one vector that lines
    # up element-for-element with `data.values`.
    data_val_to_out_val_index = range(data_row_to_out_row_start,
                                      data_row_to_out_row_limit).values

    # Recursively aggregate the values.
    output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                              data_val_to_out_val_index,
                                              output_splits[-1], separator)
    return ragged_tensor.RaggedTensor.from_row_splits(
        output_values, output_splits, validate=False)
262
263
def segment_sum(data, segment_ids, num_segments, name=None):
  # Public docs are attached by _set_ragged_segment_docstring (sum/summed).
  scope_name = name if name else 'RaggedSegmentSum'
  return _ragged_segment_aggregate(math_ops.unsorted_segment_sum, data,
                                   segment_ids, num_segments, name=scope_name)
272
273
def segment_prod(data, segment_ids, num_segments, name=None):
  # Public docs are attached by _set_ragged_segment_docstring (product).
  scope_name = name if name else 'RaggedSegmentProd'
  return _ragged_segment_aggregate(math_ops.unsorted_segment_prod, data,
                                   segment_ids, num_segments, name=scope_name)
282
283
def segment_min(data, segment_ids, num_segments, name=None):
  # Public docs are attached by _set_ragged_segment_docstring (minimum).
  scope_name = name if name else 'RaggedSegmentMin'
  return _ragged_segment_aggregate(math_ops.unsorted_segment_min, data,
                                   segment_ids, num_segments, name=scope_name)
292
293
def segment_max(data, segment_ids, num_segments, name=None):
  # Public docs are attached by _set_ragged_segment_docstring (maximum).
  scope_name = name if name else 'RaggedSegmentMax'
  return _ragged_segment_aggregate(math_ops.unsorted_segment_max, data,
                                   segment_ids, num_segments, name=scope_name)
302
303
def segment_mean(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentMean',
                      [data, segment_ids, num_segments]):
    # Normalize python values / RaggedTensorValues so the ragged check below
    # and the `flat_values` access are well defined.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    total = segment_sum(data, segment_ids, num_segments)
    # Count the values per segment by summing a tensor of ones that has the
    # same structure as `data`.  Dense `data` previously crashed here because
    # only the ragged form (`flat_values`) was handled.
    if ragged_tensor.is_ragged(data):
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(data.flat_values),
          data.nested_row_splits,
          validate=False)
    else:
      ones = array_ops.ones_like(data)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values / count.flat_values)
    else:
      return total / count
318
319
def segment_sqrt_n(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentSqrtN',
                      [data, segment_ids, num_segments]):
    # Normalize python values / RaggedTensorValues so the ragged check below
    # and the `flat_values` access are well defined.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    total = segment_sum(data, segment_ids, num_segments)
    # Count the values per segment by summing a tensor of ones that has the
    # same structure as `data`; dense `data` is supported as well.
    if ragged_tensor.is_ragged(data):
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(data.flat_values),
          data.nested_row_splits,
          validate=False)
    else:
      ones = array_ops.ones_like(data)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values /
                                    math_ops.sqrt(count.flat_values))
    else:
      return total / math_ops.sqrt(count)
335
336
def _set_ragged_segment_docstring(func, combination, combined):
  """Formats `_RAGGED_SEGMENT_DOCSTRING` and installs it as `func.__doc__`.

  Args:
    func: The function whose `__doc__` is overwritten.
    combination: Noun substituted for `%(combination)s` (e.g. 'sum').
    combined: Participle substituted for `%(combined)s` (e.g. 'summed').
  """
  substitutions = {'combination': combination, 'combined': combined}
  func.__doc__ = _RAGGED_SEGMENT_DOCSTRING % substitutions
340
341
# Attach the generated docstrings to each ragged segment op.  The loop
# temporaries are deleted afterwards so they do not leak into the module.
for _seg_func, _seg_combination, _seg_combined in (
    (segment_sum, 'sum', 'summed'),
    (segment_prod, 'product', 'multiplied'),
    (segment_min, 'minimum', 'minimized'),
    (segment_max, 'maximum', 'maximized'),
    (segment_mean, 'mean', 'averaged'),
    (segment_sqrt_n, 'sum divided by sqrt(N)', 'summed'),
):
  _set_ragged_segment_docstring(_seg_func, _seg_combination, _seg_combined)
del _seg_func, _seg_combination, _seg_combined
349
350#===============================================================================
351# ragged_reduce_<AGGREGATE>
352#===============================================================================
353
354# Docstring template used for ragged_reduce_<AGGREGATE> ops.
355_RAGGED_REDUCE_DOCSTRING = """\
356Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.
357
358  Reduces `input_tensor` along the dimensions given in `axis` by taking the
359  %(combination)s of values.  If a reduced dimension has no elements for
360  some index, then the value for that index will be %(default)s.
361
362  The rank of the tensor is reduced by `1` for each entry in `axis`.  If
363  `axis` is not specified, then all dimensions are reduced, and a scalar
364  value is returned.
365  Args:
366    input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
367    axis: The dimensions to reduce.  May be `None` (to reduce all axes), an
368      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
369      a given set of axes), or a `Tensor` with a constant value.  Must be in
370      the range `[0, input_tensor.rank]`.
371    name: A name prefix for the returned tensor (optional).
372  Returns:
373    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
374    has the same dtype as `data`, and its shape is given by removing the
375    dimensions specified in `axis` from `input_tensor.shape`.  The `ragged_rank`
376    of the returned tensor is given by substracting any ragged dimensions
377    specified in `axis` from `input_tensor.ragged_rank`.
378  Raises:
379    ValueError: If `axis` contains a `Tensor` whose value is not constant.
380  ####Example:
381    %(example)s
382"""
# Doctest examples substituted for %(example)s in _RAGGED_REDUCE_DOCSTRING,
# one per reduction op.  Each string is executable doctest text, so its
# contents (including spacing in the expected output) must stay exact.
_RAGGED_REDUCE_SUM_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_sum(rt, axis=0).numpy()  # = [3+1+9+2, 1+5+6, 4]
    array([15, 12, 4], dtype=int32)
    >>> tf.reduce_sum(rt, axis=1).numpy()  # = [3+1+4, 1+5, 9, 2+6]
    array([8, 6, 9, 8], dtype=int32)
"""
_RAGGED_REDUCE_PROD_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_prod(rt, axis=0).numpy()  # = [3*1*9*2, 1*5*6, 4]
    array([54, 30, 4], dtype=int32)
    >>> tf.reduce_prod(rt, axis=1).numpy()  # = [3*1*4, 1*5, 9, 2*6]
    array([12, 5, 9, 12], dtype=int32)
"""
_RAGGED_REDUCE_MIN_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_min(rt, axis=0).numpy()
    array([1, 1, 4], dtype=int32)
    >>> tf.reduce_min(rt, axis=1).numpy()
    array([1, 1, 9, 2], dtype=int32)
"""
_RAGGED_REDUCE_MAX_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_max(rt, axis=0).numpy()
    array([9, 6, 4], dtype=int32)
    >>> tf.reduce_max(rt, axis=1).numpy()
    array([4, 5, 9, 6], dtype=int32)
"""
# Integer inputs produce floating-point means (see the expected output).
_RAGGED_REDUCE_MEAN_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_mean(rt, axis=0).numpy()
    array([3.75, 4.  , 4. ])
    >>> tf.reduce_mean(rt, axis=1).numpy()
    array([2.66666667, 3.  , 9.  , 4.  ])
"""
_RAGGED_REDUCE_ALL_EXAMPLE = """
    >>> rt = tf.ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> tf.reduce_all(rt, axis=0).numpy()
    array([False,  True, False,  True])
    >>> tf.reduce_all(rt, axis=1).numpy()
    array([ True, False, False])
"""
_RAGGED_REDUCE_ANY_EXAMPLE = """
    >>> rt = tf.ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> tf.reduce_any(rt, axis=0).numpy()
    array([ True,  True, False,  True])
    >>> tf.reduce_any(rt, axis=1).numpy()
    array([ True,  True,  True])
"""
432
433
def ragged_reduce_aggregate(reduce_op,
                            unsorted_segment_op,
                            rt_input,
                            axis,
                            keepdims,
                            separator=None,
                            name=None):
  """Aggregates across axes of a RaggedTensor using the given `Tensor` ops.

  Reduces `rt_input` along the dimensions given in `axis`.  The rank of the
  tensor is reduced by 1 for each entry in `axis`.  If `axis` is not specified,
  then all dimensions are reduced, and a scalar value is returned (unless
  `keepdims` is true, in which case every dimension is kept with length 1).

  This op assumes that `reduce_op` and `unsorted_segment_op` are associative;
  if not, then reducing multiple axes will return incorrect results.  (In
  particular, reducing multiple axes is currently implemented by reducing the
  axes one at a time.)

  Args:
    reduce_op: The tensorflow `op` that should be used to reduce values in
      uniform dimensions.  Must have the same signature and basic behavior as
      `reduce_sum`, `reduce_max`, etc.
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in ragged dimensions.  Must have the same signature and basic
      behavior as `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    rt_input: A `Tensor` or `RaggedTensor` containing the values to be reduced.
    axis: The axis or axes to reduce.  May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a
      given set of axes), or a `Tensor` with a constant value.  Must be in the
      range `[-rank(rt_input), rank(rt_input))`; negative values count from
      the last dimension.
    keepdims: If true, retains reduced dimensions with length 1.
    separator: An optional string. Defaults to None. The separator to use when
      joining. The separator must not be set for non-string data types. (i.e. if
      separator is not None then it uses string ops)
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the reduced values.  The returned tensor
    has the same dtype as `rt_input`, and its shape is given by removing the
    dimensions specified in `axis` from `rt_input.shape`.  The `ragged_rank`
    of the returned tensor is given by subtracting any ragged dimensions
    specified in `axis` from `rt_input.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  """
  if not ragged_tensor.is_ragged(rt_input):
    if separator is None:
      return reduce_op(rt_input, axis, keepdims=keepdims, name=name)
    else:
      # When separator is not None, We infer that dtype is string and
      # reduce_join will be called.
      return reduce_op(
          rt_input, axis, keepdims=keepdims, name=name, separator=separator)

  if isinstance(axis, ops.Tensor):
    axis = tensor_util.constant_value(axis)
    if axis is None:
      raise ValueError('axis must be known at graph construction time.')
    if isinstance(axis, np.ndarray):
      axis = axis.tolist()

  # When reducing all axes, just ignore splits & reduce the inner values.
  if axis is None:
    if separator is None:
      result = reduce_op(
          rt_input.flat_values, None, keepdims=keepdims, name=name)
    else:
      # Forward the separator so that joining over all axes honors it, just
      # like the dense and per-axis paths do.
      result = reduce_op(
          rt_input.flat_values, None, keepdims=keepdims, name=name,
          separator=separator)
    if keepdims:
      # `reduce_op` already kept one length-1 dimension per flat_values
      # dimension; add one leading length-1 dimension per ragged dimension so
      # the result has the same rank as `rt_input`.  (Counting
      # `rt_input.shape[1:]` instead over-counts when flat_values has
      # rank > 1, and fails when the static rank is unknown.)
      for _ in rt_input.nested_row_splits:
        result = array_ops.expand_dims(result, axis=0)
    return result

  with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]):
    if isinstance(axis, (tuple, list)):
      if not axis:
        return rt_input
      elif len(axis) == 1:
        axis = axis[0]
      else:
        # When reducing multiple axes, as we reduce one at a time (see below),
        # the negative axis has to be converted to positive at the first run
        # as the sort with negative axis will have different orders.
        # See GitHub issue 27497.
        axis = [
            array_ops.get_positive_axis(a, rt_input.shape.ndims, 'axis[%s]' % i,
                                        'rank(input_tensor)')
            for i, a in enumerate(axis)
        ]
        # When reducing multiple axes, just reduce one at a time.  This is less
        # efficient, and only works for associative ops.  (In particular, it
        # does not work for reduce_mean.)  However, reducing multiple axes at
        # once will probably require a nontrivial c++ op.
        # Reducing the largest axis first keeps the smaller indices valid.
        axis = sorted(axis)
        inner_reduced = ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                                rt_input, axis[-1], keepdims,
                                                separator)
        return ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                       inner_reduced, axis[:-1], keepdims,
                                       separator)

    rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        rt_input, name='rt_input')

    axis = array_ops.get_positive_axis(
        axis, rt_input.shape.ndims, ndims_name='rank(input_tensor)')

    if axis == 0:
      # out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N]
      row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1]
      num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0)
      # `range` is this module's ragged range: segment id of each value is its
      # column index within its row.
      segment_ids = range(row_lengths).values
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=0)
      return result
    elif axis == 1:
      # out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N]
      num_segments = array_ops.shape(rt_input.row_splits)[0] - 1
      segment_ids = segment_id_ops.row_splits_to_segment_ids(
          rt_input.row_splits)
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=1)
      return result
    else:
      # out[i_0, ..., i_[axis-1], i_axis+1], ..., i_N] =
      #     sum_{j} rt_input [i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N]
      return rt_input.with_values(
          ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                  rt_input.values, axis - 1, keepdims,
                                  separator))
565
566
def reduce_sum(input_tensor, axis=None, keepdims=None, name=None):
  """Ragged sum reduction; full docs come from _RAGGED_REDUCE_DOCSTRING."""
  scope = name if name else 'RaggedReduceSum'
  return ragged_reduce_aggregate(math_ops.reduce_sum,
                                 math_ops.unsorted_segment_sum, input_tensor,
                                 axis, keepdims, name=scope)
577
578
def reduce_prod(input_tensor, axis=None, keepdims=None, name=None):
  """Ragged product reduction; full docs come from _RAGGED_REDUCE_DOCSTRING."""
  scope = name if name else 'RaggedReduceProd'
  return ragged_reduce_aggregate(math_ops.reduce_prod,
                                 math_ops.unsorted_segment_prod, input_tensor,
                                 axis, keepdims, name=scope)
588
589
def reduce_min(input_tensor, axis=None, keepdims=None, name=None):
  """Ragged minimum reduction; full docs come from _RAGGED_REDUCE_DOCSTRING."""
  scope = name if name else 'RaggedReduceMin'
  return ragged_reduce_aggregate(math_ops.reduce_min,
                                 math_ops.unsorted_segment_min, input_tensor,
                                 axis, keepdims, name=scope)
599
600
def reduce_max(input_tensor, axis=None, keepdims=None, name=None):
  """Ragged maximum reduction; full docs come from _RAGGED_REDUCE_DOCSTRING."""
  scope = name if name else 'RaggedReduceMax'
  return ragged_reduce_aggregate(math_ops.reduce_max,
                                 math_ops.unsorted_segment_max, input_tensor,
                                 axis, keepdims, name=scope)
610
611
def reduce_mean(input_tensor, axis=None, keepdims=None, name=None):
  """Ragged mean reduction; full docs come from _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
    # mean = sum(values) / count(values), where the count is obtained by
    # summing a tensor of ones that shares the input's (ragged) structure.
    total = reduce_sum(input_tensor, axis, keepdims)
    if not ragged_tensor.is_ragged(input_tensor):
      ones = array_ops.ones_like(input_tensor)
    else:
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(input_tensor.flat_values),
          input_tensor.nested_row_splits,
          validate=False)
    count = reduce_sum(ones, axis, keepdims)
    if not ragged_tensor.is_ragged(total):
      return total / count
    mean_values = total.flat_values / count.flat_values
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        mean_values, total.nested_row_splits, validate=False)
631
632
def _cast(input_tensor, dtype):
  """Casts the values of a (possibly ragged) tensor to `dtype`.

  The cast is applied through `ragged_functional_ops.map_flat_values`, so for
  a `RaggedTensor` it operates on the flat values.
  """

  def _cast_values(values):
    return math_ops.cast(values, dtype)

  return ragged_functional_ops.map_flat_values(_cast_values, input_tensor)
636
637
def reduce_all(input_tensor, axis=None, keepdims=None, name=None):
  """Ragged logical-and reduction; docs come from _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
    # Logical AND == product of 0/1 integers: any False (0) zeroes the result,
    # and an empty product (1) maps back to True.
    as_ints = _cast(input_tensor, dtypes.int32)
    product = reduce_prod(as_ints, axis, keepdims)
    return _cast(product, dtypes.bool)
644
645
def reduce_any(input_tensor, axis=None, keepdims=None, name=None):
  """Ragged logical-or reduction; docs come from _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
    # Logical OR == nonzero sum of 0/1 integers; an empty sum (0) maps back
    # to False.
    as_ints = _cast(input_tensor, dtypes.int32)
    total = reduce_sum(as_ints, axis, keepdims)
    return _cast(total, dtypes.bool)
652
653
def _set_ragged_reduce_docstring(func, combination, combined, default, example):
  """Formats `_RAGGED_REDUCE_DOCSTRING` and installs it as `func.__doc__`.

  Args:
    func: The function whose `__doc__` is overwritten.
    combination: Noun substituted for `%(combination)s` (e.g. 'sum').
    combined: Participle substituted for `%(combined)s` (e.g. 'summed').
    default: Text substituted for `%(default)s` (value used for empty rows).
    example: Doctest text substituted for `%(example)s`.
  """
  substitutions = {
      'combination': combination,
      'combined': combined,
      'default': default,
      'example': example,
  }
  func.__doc__ = _RAGGED_REDUCE_DOCSTRING % substitutions
660
661
# Attach generated docstrings to the ragged reduction ops.  The `default`
# text describes the value produced for an index with no elements.  For
# min/max this is the fill value of the unsorted segment op: the *largest*
# value of the dtype for a minimum, and the *smallest* for a maximum.
_set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
                             _RAGGED_REDUCE_SUM_EXAMPLE)
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
                             _RAGGED_REDUCE_PROD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
                             '`input_tensor.dtype.max`',
                             _RAGGED_REDUCE_MIN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
                             '`input_tensor.dtype.min`',
                             _RAGGED_REDUCE_MAX_EXAMPLE)
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
                             _RAGGED_REDUCE_MEAN_EXAMPLE)

_set_ragged_reduce_docstring(reduce_all, 'logical and', 'and-ed', 'True',
                             _RAGGED_REDUCE_ALL_EXAMPLE)
_set_ragged_reduce_docstring(reduce_any, 'logical or', 'or-ed', 'False',
                             _RAGGED_REDUCE_ANY_EXAMPLE)
679