1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Support for ragged tensors."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import check_ops
28from tensorflow.python.ops import gen_ragged_math_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops.ragged import ragged_functional_ops
31from tensorflow.python.ops.ragged import ragged_tensor
32from tensorflow.python.ops.ragged import ragged_util
33from tensorflow.python.ops.ragged import segment_id_ops
34from tensorflow.python.util.tf_export import tf_export
35
36
37#===============================================================================
38# ragged.range
39#===============================================================================
40# pylint: disable=redefined-builtin
@tf_export('ragged.range')
def range(starts, limits=None, deltas=1, dtype=None, name=None):
  """Returns a `RaggedTensor` containing the specified sequences of numbers.

  Each row of the returned `RaggedTensor` contains a single sequence:

  ```python
  ragged.range(starts, limits, deltas)[i] ==
      tf.range(starts[i], limits[i], deltas[i])
  ```

  If `starts[i] > limits[i] and deltas[i] > 0`, then `output[i]` will be an
  empty list.  Similarly, if `starts[i] < limits[i] and deltas[i] < 0`, then
  `output[i]` will be an empty list.  This behavior is consistent with the
  Python `range` function, but differs from the `tf.range` op, which returns
  an error for these cases.

  Examples:

  ```python
  >>> ragged.range([3, 5, 2]).eval().tolist()
  [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]]
  >>> ragged.range([0, 5, 8], [3, 3, 12]).eval().tolist()
  [[0, 1, 2], [], [8, 9, 10, 11]]
  >>> ragged.range([0, 5, 8], [3, 3, 12], 2).eval().tolist()
  [[0, 2], [], [8, 10]]
  ```

  The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors.
  The vector inputs must all have the same size.  Scalar inputs are broadcast
  to match the size of the vector inputs.

  Args:
    starts: Vector or scalar `Tensor`.  Specifies the first entry for each range
      if `limits` is not `None`; otherwise, specifies the range limits, and the
      first entries default to `0`.
    limits: Vector or scalar `Tensor`.  Specifies the exclusive upper limits for
      each range.
    deltas: Vector or scalar `Tensor`.  Specifies the increment for each range.
      Defaults to `1`.
    dtype: The type of the elements of the resulting tensor.  If not specified,
      then a value is chosen based on the other args.
    name: A name for the operation.

  Returns:
    A `RaggedTensor` of type `dtype` with `ragged_rank=1`.
  """
  # Single-argument form: the lone argument gives the limits, and every row
  # starts at zero (mirroring Python's `range(limit)`).
  if limits is None:
    starts, limits = 0, starts

  with ops.name_scope(name, 'RaggedRange', [starts, limits, deltas]) as name:
    starts = ops.convert_to_tensor(starts, dtype=dtype, name='starts')
    limits = ops.convert_to_tensor(limits, dtype=dtype, name='limits')
    deltas = ops.convert_to_tensor(deltas, dtype=dtype, name='deltas')

    # With no explicit dtype the three inputs may have been converted to
    # different dtypes; cast them all to the highest-ranked one present.
    if dtype is None:
      starts, limits, deltas = _infer_matching_dtype(
          [starts, limits, deltas],
          [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])

    result = gen_ragged_math_ops.ragged_range(starts, limits, deltas, name=name)
    return ragged_tensor.RaggedTensor.from_row_splits(result.rt_dense_values,
                                                      result.rt_nested_splits)
105
106
def _infer_matching_dtype(tensors, dtype_hierarchy):
  """Casts `tensors` to the highest-ranked dtype found among them.

  Args:
    tensors: A list of objects with a `dtype` attribute.  Every dtype must
      appear in `dtype_hierarchy`.
    dtype_hierarchy: A list of dtypes ordered from lowest to highest rank; the
      dtype with the largest index among the inputs is selected.

  Returns:
    A list with each element of `tensors` cast to the selected dtype.
  """
  input_dtypes = [tensor.dtype for tensor in tensors]
  assert all(input_dtype in dtype_hierarchy for input_dtype in input_dtypes)
  # Rank each dtype by its position in the hierarchy and keep the highest.
  target_dtype = max(input_dtypes, key=dtype_hierarchy.index)
  return [math_ops.cast(tensor, target_dtype) for tensor in tensors]
112
113
114#===============================================================================
115# ragged_segment_<AGGREGATE>
116#===============================================================================
117
# Docstring template used for the ragged segment_<AGGREGATE> ops.  The
# `%(combination)s` and `%(combined)s` placeholders are filled in by
# `_set_ragged_segment_docstring`.
_RAGGED_SEGMENT_DOCSTRING = """\
Computes the %(combination)s along segments of a RaggedTensor.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by taking the %(combination)s of all rows of `data`
  whose corresponding `segment_id` is `i`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    data: A `RaggedTensor` containing the values to combine.
    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
      Must be greater than or equal to zero, and less than `num_segments`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar specifying the number of
      distinct segment ids.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
"""
147
148
def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              name=None):
  """Combines rows of a RaggedTensor by segment using `unsorted_segment_op`.

  Produces a RaggedTensor `output` with `num_segments` rows.  Row `output[i]`
  is built by combining, with `unsorted_segment_op`, every row of `data` whose
  `segment_id` equals `i`.

  Row `output[i]` is as long as the longest row of `data` assigned to segment
  `i`.  A segment ID that no row of `data` maps to yields an empty output row.

  Args:
    unsorted_segment_op: The tensorflow `op` used to combine values.  Must
      have the same signature and basic behavior as `unsorted_segment_sum`,
      `unsorted_segment_max`, etc.
    data: A `RaggedTensor` holding the values to combine.
    segment_ids: A `Tensor` or `RaggedTensor` of type `int64` or `int32`.
      `segment_ids.shape` must be a prefix of `data.shape`.  Need not be
      sorted.
    num_segments: An `int32` or `int64` scalar.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the aggregated values, with the same dtype as
    `data` and shape `[num_segments] + data.shape[segment_ids.rank:]`.  When
    neither `data` nor `segment_ids` is ragged, the result of
    `unsorted_segment_op` is returned directly.

  Raises:
    ValueError: If `segment_ids` is ragged but `data` is not.
  """
  # Purely dense inputs: the dense op already does everything needed.
  if (not ragged_tensor.is_ragged(data) and
      not ragged_tensor.is_ragged(segment_ids)):
    return unsorted_segment_op(data, segment_ids, num_segments, name)

  with ops.name_scope(name, 'RaggedSegment',
                      [data, segment_ids, num_segments]) as scope:
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        segment_ids, name='segment_ids')

    if ragged_tensor.is_ragged(segment_ids):
      if not ragged_tensor.is_ragged(data):
        raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                         'but segment_ids is ragged and data is not.')
      # Both are ragged: once their outer splits are verified to agree, the
      # outer dimension can be dropped and the values aggregated directly.
      splits_agree = check_ops.assert_equal(
          segment_ids.row_splits,
          data.row_splits,
          message='segment_ids.shape must be a prefix of data.shape')
      with ops.control_dependencies([splits_agree]):
        return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                         segment_ids.values, num_segments,
                                         scope)

    segment_ids = math_ops.cast(segment_ids, dtypes.int64)

    # Length of every row in `data`.
    splits = data.row_splits
    row_lengths = splits[1:] - splits[:-1]

    # Each output row is as long as the longest data row in its segment.  The
    # outer `maximum(..., 0)` clamps the result for segments with no rows.
    longest_in_segment = math_ops.unsorted_segment_max(row_lengths, segment_ids,
                                                       num_segments)
    out_row_lengths = math_ops.maximum(longest_in_segment, 0)
    assert out_row_lengths.dtype == dtypes.int64

    # Row splits of the output: a leading zero followed by running totals.
    leading_zero = array_ops.zeros([1], dtypes.int64)
    out_splits = array_ops.concat(
        [leading_zero, math_ops.cumsum(out_row_lengths)], axis=0)

    # For every data row, the half-open span of output value positions that
    # its values are aggregated into.
    span_starts = array_ops.gather(out_splits, segment_ids)
    span_limits = span_starts + row_lengths

    # Expanding each span gives, for every entry of `data.values`, the index
    # in the output values that it contributes to.
    value_targets = range(span_starts, span_limits).values

    # Aggregate one level down, using those target indices as segment ids.
    out_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                           value_targets, out_splits[-1])
    return ragged_tensor.RaggedTensor.from_row_splits(out_values, out_splits)
243
244
def segment_sum(data, segment_ids, num_segments, name=None):
  """Sums `data` along segments; for docs, see _RAGGED_SEGMENT_DOCSTRING."""
  scope_name = name if name else 'RaggedSegmentSum'
  return _ragged_segment_aggregate(
      unsorted_segment_op=math_ops.unsorted_segment_sum,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=scope_name)
250
251
def segment_prod(data, segment_ids, num_segments, name=None):
  """Multiplies `data` along segments; see _RAGGED_SEGMENT_DOCSTRING."""
  scope_name = name if name else 'RaggedSegmentProd'
  return _ragged_segment_aggregate(
      unsorted_segment_op=math_ops.unsorted_segment_prod,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=scope_name)
257
258
def segment_min(data, segment_ids, num_segments, name=None):
  """Takes the minimum along segments; see _RAGGED_SEGMENT_DOCSTRING."""
  scope_name = name if name else 'RaggedSegmentMin'
  return _ragged_segment_aggregate(
      unsorted_segment_op=math_ops.unsorted_segment_min,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=scope_name)
264
265
def segment_max(data, segment_ids, num_segments, name=None):
  """Takes the maximum along segments; see _RAGGED_SEGMENT_DOCSTRING."""
  scope_name = name if name else 'RaggedSegmentMax'
  return _ragged_segment_aggregate(
      unsorted_segment_op=math_ops.unsorted_segment_max,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=scope_name)
271
272
def segment_mean(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING.

  The mean is computed as the segment sum of `data` divided by the segment sum
  of a same-shaped tensor of ones (i.e. the number of contributing values).
  """
  with ops.name_scope(name, 'RaggedSegmentMean',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    if ragged_tensor.is_ragged(data):
      # Mirror the ragged structure of `data`, with every value set to one.
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(data.flat_values), data.nested_row_splits)
    else:
      # Dense `data` has no `flat_values` or `nested_row_splits`; `segment_sum`
      # accepts dense inputs, so handle them here as `reduce_mean` does.
      ones = array_ops.ones_like(data)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values / count.flat_values)
    else:
      return total / count
285
286
def segment_sqrt_n(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING.

  The result is the segment sum of `data` divided by the square root of the
  segment sum of a same-shaped tensor of ones (i.e. sqrt of the number of
  contributing values).
  """
  with ops.name_scope(name, 'RaggedSegmentSqrtN',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    if ragged_tensor.is_ragged(data):
      # Mirror the ragged structure of `data`, with every value set to one.
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(data.flat_values), data.nested_row_splits)
    else:
      # Dense `data` has no `flat_values` or `nested_row_splits`; `segment_sum`
      # accepts dense inputs, so handle them here as `reduce_mean` does.
      ones = array_ops.ones_like(data)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(
          total.flat_values / math_ops.sqrt(count.flat_values))
    else:
      return total / math_ops.sqrt(count)
300
301
def _set_ragged_segment_docstring(func, combination, combined):
  """Sets `func.__doc__` from the shared ragged-segment docstring template.

  Args:
    func: The function whose `__doc__` is assigned.
    combination: Text substituted for `%(combination)s` in the template.
    combined: Text substituted for `%(combined)s` in the template.
  """
  substitutions = {'combination': combination, 'combined': combined}
  func.__doc__ = _RAGGED_SEGMENT_DOCSTRING % substitutions
305
306
# Attach the generated documentation to each ragged segment op.
_set_ragged_segment_docstring(
    segment_sum, combination='sum', combined='summed')
_set_ragged_segment_docstring(
    segment_prod, combination='product', combined='multiplied')
_set_ragged_segment_docstring(
    segment_min, combination='minimum', combined='minimized')
_set_ragged_segment_docstring(
    segment_max, combination='maximum', combined='maximized')
_set_ragged_segment_docstring(
    segment_mean, combination='mean', combined='averaged')
_set_ragged_segment_docstring(
    segment_sqrt_n, combination='sum divided by sqrt(N)', combined='summed')
314
315#===============================================================================
316# ragged_reduce_<AGGREGATE>
317#===============================================================================
318
319# Docstring template used for ragged_reduce_<AGGREGATE> ops.
# Placeholders: `%(combination)s`, `%(combined)s`, `%(default)s` and
# `%(example)s`.  The example text is expected to begin and end with a newline
# so that it sits between the opening and closing code fences.
_RAGGED_REDUCE_DOCSTRING = """\
Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.

  Reduces `input_tensor` along the dimensions given in `axis` by taking the
  %(combination)s of values.  If a reduced dimension has no elements for
  some index, then the value for that index will be %(default)s.

  The rank of the tensor is reduced by `1` for each entry in `axis`.  If
  `axis` is not specified, then all dimensions are reduced, and a scalar
  value is returned.
  Args:
    input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
    axis: The dimensions to reduce.  May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
      a given set of axes), or a `Tensor` with a constant value.  Must be in
      the range `[0, input_tensor.rank)`.
    keepdims: If true, retain reduced dimensions with length 1.  Not supported
      when `input_tensor` is a `RaggedTensor`.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
    has the same dtype as `input_tensor`, and its shape is given by removing
    the dimensions specified in `axis` from `input_tensor.shape`.  The
    `ragged_rank` of the returned tensor is given by subtracting any ragged
    dimensions specified in `axis` from `input_tensor.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant, or
      if `keepdims` is true and `input_tensor` is a `RaggedTensor`.
  ####Example:
    ```python%(example)s    ```
"""
# Example snippets substituted for `%(example)s` in `_RAGGED_REDUCE_DOCSTRING`.
# Each one begins and ends with a newline so that it sits between the opening
# and closing code fences of the template.

# Example for `reduce_sum`.
_RAGGED_REDUCE_SUM_EXAMPLE = """
    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> ragged.reduce_sum(rt, axis=0).eval().tolist()
    [15, 12, 4]  # = [3+1+9+2, 1+5+6, 4]
    >>> ragged.reduce_sum(rt, axis=1).eval().tolist()
    [8, 6, 9, 8]  # = [3+1+4, 1+5, 9, 2+6]
"""
# Example for `reduce_prod`.
_RAGGED_REDUCE_PROD_EXAMPLE = """
    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> ragged.reduce_prod(rt, axis=0).eval().tolist()
    [54, 30, 4]  # = [3*1*9*2, 1*5*6, 4]
    >>> ragged.reduce_prod(rt, axis=1).eval().tolist()
    [12, 5, 9, 12]  # = [3*1*4, 1*5, 9, 2*6]
"""
# Example for `reduce_min`.
_RAGGED_REDUCE_MIN_EXAMPLE = """
    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> ragged.reduce_min(rt, axis=0).eval().tolist()
    [1, 1, 4]  # = [min(3, 1, 9, 2), min(1, 5, 6), 4]
    >>> ragged.reduce_min(rt, axis=1).eval().tolist()
    [1, 1, 9, 2]  # = [min(3, 1, 4), min(1, 5), 9, min(2, 6)]
"""
# Example for `reduce_max`.
_RAGGED_REDUCE_MAX_EXAMPLE = """
    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> ragged.reduce_max(rt, axis=0).eval().tolist()
    [9, 6, 4]  # = [max(3, 1, 9, 2), max(1, 5, 6), 4]
    >>> ragged.reduce_max(rt, axis=1).eval().tolist()
    [4, 5, 9, 6]  # = [max(3, 1, 4), max(1, 5), 9, max(2, 6)]
"""
# Example for `reduce_mean`.
_RAGGED_REDUCE_MEAN_EXAMPLE = """
    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> ragged.reduce_mean(rt, axis=0).eval().tolist()
    [3.75, 4, 4]  # = [mean(3, 1, 9, 2), mean(1, 5, 6), 4]
    >>> ragged.reduce_mean(rt, axis=1).eval().tolist()
    [2.66666, 3, 9, 4]  # = [mean(3, 1, 4), mean(1, 5), 9, mean(2, 6)]
"""
# Example for `reduce_all`.
_RAGGED_REDUCE_ALL_EXAMPLE = """
    >>> rt = ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> ragged.reduce_all(rt, axis=0).eval().tolist()
    [False, True, False, True]
    >>> ragged.reduce_all(rt, axis=1).eval().tolist()
    [True, False, False]
"""
# Example for `reduce_any`.
_RAGGED_REDUCE_ANY_EXAMPLE = """
    >>> rt = ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> ragged.reduce_any(rt, axis=0).eval().tolist()
    [True, True, False, True]
    >>> ragged.reduce_any(rt, axis=1).eval().tolist()
    [True, True, True]
"""
397
398
def _ragged_reduce_aggregate(reduce_op,
                             unsorted_segment_op,
                             rt_input,
                             axis,
                             keepdims,
                             name=None):
  """Aggregates across axes of a RaggedTensor using the given `Tensor` ops.

  Reduces `rt_input` along the dimensions given in `axis`.  The rank of the
  tensor is reduced by 1 for each entry in `axis`.  If `axis` is not specified,
  then all dimensions are reduced, and a scalar value is returned.

  This op assumes that `reduce_op` and `unsorted_segment_op` are associative;
  if not, then reducing multiple axes will return incorrect results.  (In
  particular, reducing multiple axes is currently implemented by reducing the
  axes one at a time.)

  Args:
    reduce_op: The tensorflow `op` that should be used to reduce values in
      uniform dimensions.  Must have the same signature and basic behavior as
      `reduce_sum`, `reduce_max`, etc.
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in ragged dimensions.  Must have the same signature and basic
      behavior as `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    rt_input: A `Tensor` or `RaggedTensor` containing the values to be reduced.
    axis: The axis or axes to reduce.  May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a
      given set of axes), or a `Tensor` with a constant value.  Must be in the
      range `[0, rt_input.rank)`.
    keepdims: If true, retains reduced dimensions with length 1.  Forwarded to
      `reduce_op` when `rt_input` is not ragged; not supported otherwise.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the reduced values.  The returned tensor
    has the same dtype as `rt_input`, and its shape is given by removing the
    dimensions specified in `axis` from `rt_input.shape`.  The `ragged_rank`
    of the returned tensor is given by subtracting any ragged dimensions
    specified in `axis` from `rt_input.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant, or
      if `keepdims` is true and `rt_input` is ragged.
  """
  if not ragged_tensor.is_ragged(rt_input):
    # Dense input: defer to the dense op, forwarding `keepdims` so that the
    # caller's request is honored rather than silently dropped.
    return reduce_op(rt_input, axis, keepdims=keepdims, name=name)

  if keepdims:
    raise ValueError('keepdims=True is not supported for RaggedTensors.')

  if isinstance(axis, ops.Tensor):
    axis = tensor_util.constant_value(axis)
    if axis is None:
      raise ValueError('axis must be known at graph construction time.')
    if isinstance(axis, np.ndarray):
      axis = axis.tolist()

  # When reducing all axes, just ignore splits & reduce the inner values.
  if axis is None:
    return reduce_op(rt_input.flat_values, None, name=name)

  with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]):
    if isinstance(axis, (tuple, list)):
      if not axis:
        return rt_input
      elif len(axis) == 1:
        axis = axis[0]
      else:
        # When reducing multiple axes, just reduce one at a time.  This is less
        # efficient, and only works for associative ops.  (In particular, it
        # does not work for reduce_mean.)  However, reducing multiple axes at
        # once will probably require a nontrivial c++ op.
        #
        # Negative axes must be resolved against the rank of the *original*
        # input before sorting: sorting raw values would order e.g. `-2`
        # before `2`, and a negative axis would then be interpreted against
        # the already-reduced (lower-rank) intermediate result.
        rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            rt_input, name='rt_input')
        ndims = rt_input.shape.ndims
        axis = sorted(ragged_util.get_positive_axis(a, ndims) for a in axis)
        # Reduce the innermost requested axis first so that the remaining
        # (smaller) axis indices still refer to the same dimensions.
        inner_reduced = _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                                 rt_input, axis[-1], keepdims)
        return _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                        inner_reduced, axis[:-1], keepdims)

    rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        rt_input, name='rt_input')

    axis = ragged_util.get_positive_axis(axis, rt_input.shape.ndims)

    if axis == 0:
      # out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N]
      row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1]
      # The output has one entry per column position; `maximum(..., 0)` clamps
      # the value produced by `reduce_max` when there are no rows.
      num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0)
      # Each value's segment id is its position within its own row.
      segment_ids = range(row_lengths).values
      return _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                       segment_ids, num_segments)
    elif axis == 1:
      # out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N]
      num_segments = array_ops.shape(rt_input.row_splits)[0] - 1
      # Each value's segment id is the index of the row that contains it.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(
          rt_input.row_splits)
      return _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                       segment_ids, num_segments)
    else:
      # out[i_0, ..., i_[axis-1], i_axis+1], ..., i_N] =
      #     sum_{j} rt_input [i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N]
      # The outer dimension is untouched: recurse into the values with the
      # axis shifted down by one and keep the outer row splits.
      return rt_input.with_values(
          _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                   rt_input.values, axis - 1, keepdims))
499
500
def reduce_sum(input_tensor, axis=None, keepdims=None, name=None):
  """Sums `input_tensor` along `axis`; see _RAGGED_REDUCE_DOCSTRING."""
  scope_name = name if name else 'RaggedReduceSum'
  return _ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_sum,
      unsorted_segment_op=math_ops.unsorted_segment_sum,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=scope_name)
506
507
def reduce_prod(input_tensor, axis=None, keepdims=None, name=None):
  """Multiplies `input_tensor` along `axis`; see _RAGGED_REDUCE_DOCSTRING."""
  scope_name = name if name else 'RaggedReduceProd'
  return _ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_prod,
      unsorted_segment_op=math_ops.unsorted_segment_prod,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=scope_name)
513
514
def reduce_min(input_tensor, axis=None, keepdims=None, name=None):
  """Takes the minimum along `axis`; see _RAGGED_REDUCE_DOCSTRING."""
  scope_name = name if name else 'RaggedReduceMin'
  return _ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_min,
      unsorted_segment_op=math_ops.unsorted_segment_min,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=scope_name)
520
521
def reduce_max(input_tensor, axis=None, keepdims=None, name=None):
  """Takes the maximum along `axis`; see _RAGGED_REDUCE_DOCSTRING."""
  scope_name = name if name else 'RaggedReduceMax'
  return _ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_max,
      unsorted_segment_op=math_ops.unsorted_segment_max,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=scope_name)
527
528
def reduce_mean(input_tensor, axis=None, keepdims=None, name=None):
  """Averages `input_tensor` along `axis`; see _RAGGED_REDUCE_DOCSTRING.

  The mean is the reduced sum of `input_tensor` divided by the reduced sum of
  a same-shaped tensor of ones (i.e. the number of contributing values).
  """
  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
    summed = reduce_sum(input_tensor, axis, keepdims)

    # Build a tensor of ones with the same (possibly ragged) structure as the
    # input; reducing it yields the element count for each output position.
    if not ragged_tensor.is_ragged(input_tensor):
      unit_weights = array_ops.ones_like(input_tensor)
    else:
      unit_weights = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(input_tensor.flat_values),
          input_tensor.nested_row_splits)
    counts = reduce_sum(unit_weights, axis, keepdims)

    if not ragged_tensor.is_ragged(summed):
      return summed / counts
    # Ragged result: divide the flat values and reuse the splits of the sum.
    mean_values = summed.flat_values / counts.flat_values
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        mean_values, summed.nested_row_splits)
545
546
def _cast(input_tensor, dtype):
  """Applies `math_ops.cast` to `input_tensor` via `map_flat_values`."""
  cast_op = math_ops.cast
  return ragged_functional_ops.map_flat_values(cast_op, input_tensor, dtype)
550
551
def reduce_all(input_tensor, axis=None, keepdims=None, name=None):
  """Logical-ands along `axis`; see _RAGGED_REDUCE_DOCSTRING.

  Implemented by casting to `int32`, taking the product, and casting the
  result back to `bool`.
  """
  with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
    as_int = _cast(input_tensor, dtypes.int32)
    product = reduce_prod(as_int, axis, keepdims)
    return _cast(product, dtypes.bool)
558
559
def reduce_any(input_tensor, axis=None, keepdims=None, name=None):
  """Logical-ors along `axis`; see _RAGGED_REDUCE_DOCSTRING.

  Implemented by casting to `int32`, taking the sum, and casting the result
  back to `bool`.
  """
  with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
    as_int = _cast(input_tensor, dtypes.int32)
    total = reduce_sum(as_int, axis, keepdims)
    return _cast(total, dtypes.bool)
566
567
def _set_ragged_reduce_docstring(func, combination, combined, default, example):
  """Sets `func.__doc__` from the shared ragged-reduce docstring template.

  Args:
    func: The function whose `__doc__` is assigned.
    combination: Text substituted for `%(combination)s` in the template.
    combined: Text substituted for `%(combined)s` in the template.
    default: Text substituted for `%(default)s` in the template.
    example: Text substituted for `%(example)s` in the template.
  """
  substitutions = {
      'combination': combination,
      'combined': combined,
      'default': default,
      'example': example,
  }
  func.__doc__ = _RAGGED_REDUCE_DOCSTRING % substitutions
574
575
# Attach the generated documentation to each ragged reduce op.  The fourth
# argument documents the value produced for a reduced slice with no elements.
_set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
                             _RAGGED_REDUCE_SUM_EXAMPLE)
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
                             _RAGGED_REDUCE_PROD_EXAMPLE)
# An empty minimum yields the largest representable value, and an empty
# maximum the smallest, so the documented defaults are `dtype.max` for
# `reduce_min` and `dtype.min` for `reduce_max`.
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
                             '`input_tensor.dtype.max`',
                             _RAGGED_REDUCE_MIN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
                             '`input_tensor.dtype.min`',
                             _RAGGED_REDUCE_MAX_EXAMPLE)
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
                             _RAGGED_REDUCE_MEAN_EXAMPLE)

_set_ragged_reduce_docstring(reduce_all, 'logical and', 'and-ed', 'True',
                             _RAGGED_REDUCE_ALL_EXAMPLE)
_set_ragged_reduce_docstring(reduce_any, 'logical or', 'or-ed', 'False',
                             _RAGGED_REDUCE_ANY_EXAMPLE)
593