1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Support for ragged tensors."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
21import numpy as np
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import check_ops
28from tensorflow.python.ops import gen_ragged_math_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops.ragged import ragged_functional_ops
31from tensorflow.python.ops.ragged import ragged_tensor
32from tensorflow.python.ops.ragged import ragged_util
33from tensorflow.python.ops.ragged import segment_id_ops
34from tensorflow.python.util.tf_export import tf_export
38# ragged.range
40# pylint: disable=redefined-builtin
def range(starts, limits=None, deltas=1, dtype=None, name=None):
  """Returns a `RaggedTensor` containing the specified sequences of numbers.

  Each row of the returned `RaggedTensor` contains a single sequence:

  ```python
  ragged.range(starts, limits, deltas)[i] ==
      tf.range(starts[i], limits[i], deltas[i])
  ```

  If `starts[i] >= limits[i] and deltas[i] > 0`, then `output[i]` will be an
  empty list.  Similarly, if `starts[i] <= limits[i] and deltas[i] < 0`, then
  `output[i]` will be an empty list.  This behavior is consistent with the
  Python `range` function, but differs from the `tf.range` op, which returns
  an error when `starts[i]` lies beyond `limits[i]` in the direction of
  `deltas[i]`.

  Examples:

  ```python
  >>> ragged.range([3, 5, 2]).eval().tolist()
  [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]]
  >>> ragged.range([0, 5, 8], [3, 3, 12]).eval().tolist()
  [[0, 1, 2], [], [8, 9, 10, 11]]
  >>> ragged.range([0, 5, 8], [3, 3, 12], 2).eval().tolist()
  [[0, 2], [], [8, 10]]
  ```

  The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors.
  The vector inputs must all have the same size.  Scalar inputs are broadcast
  to match the size of the vector inputs.

  Args:
    starts: Vector or scalar `Tensor`.  Specifies the first entry for each range
      if `limits` is not `None`; otherwise, specifies the range limits, and the
      first entries default to `0`.
    limits: Vector or scalar `Tensor`.  Specifies the exclusive upper limits for
      each range.
    deltas: Vector or scalar `Tensor`.  Specifies the increment for each range.
      Defaults to `1`.
    dtype: The type of the elements of the resulting tensor.  If not specified,
      then a value is chosen based on the other args.
    name: A name for the operation.

  Returns:
    A `RaggedTensor` of type `dtype` with `ragged_rank=1`.
  """
  # Single-argument form `range(n)`: the argument is the limits, and every
  # range starts at 0 (mirroring Python's builtin `range`).
  if limits is None:
    starts, limits = 0, starts

  with ops.name_scope(name, 'RaggedRange', [starts, limits, deltas]) as name:
    starts = ops.convert_to_tensor(starts, dtype=dtype, name='starts')
    limits = ops.convert_to_tensor(limits, dtype=dtype, name='limits')
    deltas = ops.convert_to_tensor(deltas, dtype=dtype, name='deltas')

    # Without an explicit dtype, each input was converted independently, so
    # promote all three to the "widest" dtype among them.
    if dtype is None:
      starts, limits, deltas = _infer_matching_dtype(
          [starts, limits, deltas],
          [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])

    result = gen_ragged_math_ops.ragged_range(starts, limits, deltas, name=name)
    return ragged_tensor.RaggedTensor.from_row_splits(result.rt_dense_values,
                                                      result.rt_nested_splits)
107def _infer_matching_dtype(tensors, dtype_hierarchy):
108  """Infers a matching dtype for tensors, and casts them to that dtype."""
109  assert all(t.dtype in dtype_hierarchy for t in tensors)
110  inferred_dtype = max([t.dtype for t in tensors], key=dtype_hierarchy.index)
111  return [math_ops.cast(t, inferred_dtype) for t in tensors]
115# ragged_segment_<AGGREGATE>
# Docstring template used for the ragged_segment_<AGGREGATE> ops.
120Computes the %(combination)s along segments of a RaggedTensor.
122  Returns a RaggedTensor `output` with `num_segments` rows, where the row
123  `output[i]` is formed by taking the %(combination)s of all rows of `data`
124  whose corresponding `segment_id` is `i`.
126  The length of the row `output[i]` will be the maximum of the lengths of
127  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
128  rows correspond to a given segment ID, then the output row for that segment
129  ID will be empty.
131  Args:
132    data: A `RaggedTensor` containing the values to combine.
133    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
134      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
135      Must be greater than or equal to zero, and less than `num_segments`.
136      `segment_ids` is not required to be sorted.
137    num_segments: An `int32` or `int64` scalar specifying the number of
138      distinct segment ids.
139    name: A name prefix for the returned tensor (optional).
140  Returns:
141    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
142    has the same dtype as `data`, and its shape is
143    `[num_segments] + data.shape[segment_ids.rank:]`.
144  Raises:
145    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              name=None):
  """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by combining all rows of `data` whose corresponding
  `segment_id` is `i`.  The values in each row are combined using
  `unsorted_segment_op`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  If neither `data` nor `segment_ids` is ragged, this simply returns
  `unsorted_segment_op(data, segment_ids, num_segments, name)`.  Otherwise the
  outermost dimension is handled here, and the inner dimensions are handled by
  a recursive call on `data.values`.

  Args:
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in each row.  Must have the same signature and basic behavior as
      `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    data: A `RaggedTensor` containing the values to be combined.
    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the aggregated values.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If `segment_ids` is ragged but `data` is not.  (When both are
      ragged, a mismatch between their row splits is detected when the graph
      runs, by a `check_ops.assert_equal` op, not at construction time.)
  """
  # Neither input is ragged: defer entirely to the dense segment op.
  if not (ragged_tensor.is_ragged(data) or
          ragged_tensor.is_ragged(segment_ids)):
    return unsorted_segment_op(data, segment_ids, num_segments, name)

  # `name` is rebound to the scope name; the ragged-`segment_ids` branch below
  # passes it on to its recursive call.
  with ops.name_scope(name, 'RaggedSegment',
                      [data, segment_ids, num_segments]) as name:
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        segment_ids, name='segment_ids')

    # Ragged segment_ids: both tensors must share the same outer row
    # partitioning, in which case the outer ragged dimension can be peeled off
    # and the aggregation recurses on the inner values.
    if ragged_tensor.is_ragged(segment_ids):
      if not ragged_tensor.is_ragged(data):
        raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                         'but segment_ids is ragged and data is not.')
      check_splits = check_ops.assert_equal(
          segment_ids.row_splits,
          data.row_splits,
          message='segment_ids.shape must be a prefix of data.shape')
      with ops.control_dependencies([check_splits]):
        return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                         segment_ids.values, num_segments, name)

    # From here on, `segment_ids` is dense and `data` is ragged.
    segment_ids = math_ops.cast(segment_ids, dtypes.int64)

    # Find the length of each row in data.  (dtype=int64, shape=[data_nrows])
    data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

    # Find the length that each output row will have.  The length of the row
    # corresponding to segment `id` is `max(data_row_lengths[i])` where
    # `segment_ids[i]=id`.  (dtype=int64, shape=[output_nrows])
    # The `maximum(..., 0)` clamp makes segments with no corresponding data
    # rows produce empty (length-0) output rows.
    output_row_lengths = math_ops.maximum(
        math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
                                      num_segments), 0)
    assert output_row_lengths.dtype == dtypes.int64

    # Build the splits tensor for the output RaggedTensor: a leading 0
    # followed by the running total of the output row lengths.
    output_splits = array_ops.concat([
        array_ops.zeros([1], dtypes.int64),
        math_ops.cumsum(output_row_lengths)
    ],
                                     axis=0)

    # For each row in `data`, find the start & limit position where that row's
    # values will be aggregated in output.values.
    data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids)
    data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths

    # For each value in `data.values`, find the position where it will be
    # aggregated in `output.values`.  Note that `range` here is this module's
    # ragged `range` (which shadows the builtin); `.values` flattens its rows
    # into one index per data value.
    data_val_to_out_val_index = range(data_row_to_out_row_start,
                                      data_row_to_out_row_limit).values

    # Recursively aggregate the values; the number of output values is the
    # final entry of `output_splits`.
    output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                              data_val_to_out_val_index,
                                              output_splits[-1])
    return ragged_tensor.RaggedTensor.from_row_splits(output_values,
                                                      output_splits)
def segment_sum(data, segment_ids, num_segments, name=None):
  # __doc__ is filled in from _RAGGED_SEGMENT_DOCSTRING at import time.
  return _ragged_segment_aggregate(
      unsorted_segment_op=math_ops.unsorted_segment_sum,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=name or 'RaggedSegmentSum')
def segment_prod(data, segment_ids, num_segments, name=None):
  # __doc__ is filled in from _RAGGED_SEGMENT_DOCSTRING at import time.
  return _ragged_segment_aggregate(
      unsorted_segment_op=math_ops.unsorted_segment_prod,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=name or 'RaggedSegmentProd')
def segment_min(data, segment_ids, num_segments, name=None):
  # __doc__ is filled in from _RAGGED_SEGMENT_DOCSTRING at import time.
  return _ragged_segment_aggregate(
      unsorted_segment_op=math_ops.unsorted_segment_min,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=name or 'RaggedSegmentMin')
def segment_max(data, segment_ids, num_segments, name=None):
  # __doc__ is filled in from _RAGGED_SEGMENT_DOCSTRING at import time.
  return _ragged_segment_aggregate(
      unsorted_segment_op=math_ops.unsorted_segment_max,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=name or 'RaggedSegmentMax')
def _segment_sum_and_count(data, segment_ids, num_segments):
  """Returns `(total, count)` used by the segment mean / sqrt(N) reductions.

  `total` is `segment_sum(data, segment_ids, num_segments)`, and `count` is
  the segment sum of a ones tensor with the same structure as `data`, i.e.
  how many values were combined into each output position.

  Args:
    data: A `RaggedTensor`, or a non-ragged `Tensor`-like value.
    segment_ids: The segment ids, as accepted by `segment_sum`.
    num_segments: The number of segments, as accepted by `segment_sum`.

  Returns:
    A `(total, count)` tuple, both produced by `segment_sum`.
  """
  total = segment_sum(data, segment_ids, num_segments)
  if ragged_tensor.is_ragged(data):
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values), data.nested_row_splits)
  else:
    # Dense data has no `flat_values` / `nested_row_splits`, so build the ones
    # tensor directly (the same approach `reduce_mean` uses).
    ones = array_ops.ones_like(data)
  count = segment_sum(ones, segment_ids, num_segments)
  return total, count


def segment_mean(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentMean',
                      [data, segment_ids, num_segments]):
    total, count = _segment_sum_and_count(data, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values / count.flat_values)
    return total / count


def segment_sqrt_n(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentSqrtN',
                      [data, segment_ids, num_segments]):
    total, count = _segment_sum_and_count(data, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(
          total.flat_values / math_ops.sqrt(count.flat_values))
    return total / math_ops.sqrt(count)
def _set_ragged_segment_docstring(func, combination, combined):
  """Assigns `func.__doc__` by filling in `_RAGGED_SEGMENT_DOCSTRING`.

  Args:
    func: The segment function whose `__doc__` is overwritten.
    combination: Substituted for `%(combination)s` (e.g. 'sum').
    combined: Substituted for `%(combined)s` (e.g. 'summed').
  """
  template_args = {'combination': combination, 'combined': combined}
  func.__doc__ = _RAGGED_SEGMENT_DOCSTRING % template_args
# Generate the public docstrings of the ragged segment ops from the template.
_set_ragged_segment_docstring(
    segment_sum, combination='sum', combined='summed')
_set_ragged_segment_docstring(
    segment_prod, combination='product', combined='multiplied')
_set_ragged_segment_docstring(
    segment_min, combination='minimum', combined='minimized')
_set_ragged_segment_docstring(
    segment_max, combination='maximum', combined='maximized')
_set_ragged_segment_docstring(
    segment_mean, combination='mean', combined='averaged')
_set_ragged_segment_docstring(
    segment_sqrt_n, combination='sum divided by sqrt(N)', combined='summed')
316# ragged_reduce_<AGGREGATE>
319# Docstring template used for ragged_reduce_<AGGREGATE> ops.
321Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.
323  Reduces `input_tensor` along the dimensions given in `axis` by taking the
324  %(combination)s of values.  If a reduced dimension has no elements for
325  some index, then the value for that index will be %(default)s.
327  The rank of the tensor is reduced by `1` for each entry in `axis`.  If
328  `axis` is not specified, then all dimensions are reduced, and a scalar
329  value is returned.
330  Args:
331    input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
332    axis: The dimensions to reduce.  May be `None` (to reduce all axes), an
333      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
334      a given set of axes), or a `Tensor` with a constant value.  Must be in
      the range `[0, input_tensor.rank)`.
336    name: A name prefix for the returned tensor (optional).
337  Returns:
    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
    has the same dtype as `input_tensor`, and its shape is given by removing the
    dimensions specified in `axis` from `input_tensor.shape`.  The `ragged_rank`
    of the returned tensor is given by subtracting any ragged dimensions
    specified in `axis` from `input_tensor.ragged_rank`.
343  Raises:
344    ValueError: If `axis` contains a `Tensor` whose value is not constant.
345  ####Example:
346    ```python%(example)s    ```
349    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
350    >>> ragged.reduce_sum(rt, axis=0).eval().tolist()
351    [15, 12, 4]  # = [3+1+9+2, 1+5+6, 4]
352    >>> ragged.reduce_sum(rt, axis=1).eval().tolist()
353    [8, 6, 9, 8]  # = [3+1+4, 1+5, 9, 2+6]
356    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
357    >>> ragged.reduce_prod(rt, axis=0).eval().tolist()
358    [54, 30, 4]  # = [3*1*9*2, 1*5*6, 4]
359    >>> ragged.reduce_prod(rt, axis=1).eval().tolist()
360    [12, 5, 9, 12]  # = [3*1*4, 1*5, 9, 2*6]
363    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
364    >>> ragged.reduce_min(rt, axis=0).eval().tolist()
365    [1, 1, 4]  # = [min(3, 1, 9, 2), min(1, 5, 6), 4]
366    >>> ragged.reduce_min(rt, axis=1).eval().tolist()
367    [1, 1, 9, 2]  # = [min(3, 1, 4), min(1, 5), 9, min(2, 6)]
370    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
371    >>> ragged.reduce_max(rt, axis=0).eval().tolist()
372    [9, 6, 4]  # = [max(3, 1, 9, 2), max(1, 5, 6), 4]
373    >>> ragged.reduce_max(rt, axis=1).eval().tolist()
374    [4, 5, 9, 6]  # = [max(3, 1, 4), max(1, 5), 9, max(2, 6)]
377    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
378    >>> ragged.reduce_mean(rt, axis=0).eval().tolist()
379    [3.75, 4, 4]  # = [mean(3, 1, 9, 2), mean(1, 5, 6), 4]
380    >>> ragged.reduce_mean(rt, axis=1).eval().tolist()
381    [2.66666, 3, 9, 4]  # = [mean(3, 1, 4), mean(1, 5), 9, mean(2, 6)]
384    >>> rt = ragged.constant([[True, True], [True, True, False, True], [False, True]])
385    >>> ragged.reduce_all(rt, axis=0).eval().tolist()
386    [False, True, False, True]
387    >>> ragged.reduce_all(rt, axis=1).eval().tolist()
388    [True, False, False]
391    >>> rt = ragged.constant([[True, True], [True, True, False, True], [False, True]])
392    >>> ragged.reduce_any(rt, axis=0).eval().tolist()
393    [True, True, False, True]
394    >>> ragged.reduce_any(rt, axis=1).eval().tolist()
395    [True, True, True]
def _ragged_reduce_aggregate(reduce_op,
                             unsorted_segment_op,
                             rt_input,
                             axis,
                             keepdims,
                             name=None):
  """Aggregates across axes of a RaggedTensor using the given `Tensor` ops.

  Reduces `rt_input` along the dimensions given in `axis`.  The rank of the
  tensor is reduced by 1 for each entry in `axis`.  If `axis` is not specified,
  then all dimensions are reduced, and a scalar value is returned.

  This op assumes that `reduce_op` and `unsorted_segment_op` are associative;
  if not, then reducing multiple axes will return incorrect results.  (In
  particular, reducing multiple axes is currently implemented by reducing the
  axes one at a time.)

  Args:
    reduce_op: The tensorflow `op` that should be used to reduce values in
      uniform dimensions.  Must have the same signature and basic behavior as
      `reduce_sum`, `reduce_max`, etc.
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in ragged dimensions.  Must have the same signature and basic
      behavior as `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    rt_input: A `Tensor` or `RaggedTensor` containing the values to be reduced.
    axis: The axis or axes to reduce.  May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a
      given set of axes), or a `Tensor` with a constant value.  Must be in the
      range `[0, rt_input.rank)`; negative values are normalized with
      `ragged_util.get_positive_axis` against `rt_input.shape.ndims`.
    keepdims: If true, retains reduced dimensions with length 1.  Only
      supported when `rt_input` is not ragged, in which case it is forwarded
      to `reduce_op`.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the reduced values.  The returned tensor
    has the same dtype as `rt_input`, and its shape is given by removing the
    dimensions specified in `axis` from `rt_input.shape`.  The `ragged_rank`
    of the returned tensor is given by subtracting any ragged dimensions
    specified in `axis` from `rt_input.ragged_rank`.
  Raises:
    ValueError: If `keepdims` is true and `rt_input` is ragged, or if `axis`
      contains a `Tensor` whose value is not constant.
  """
  if not ragged_tensor.is_ragged(rt_input):
    # Dense input: the dense reduction handles everything, including
    # `keepdims`, which must be forwarded rather than silently dropped.
    return reduce_op(rt_input, axis, keepdims=keepdims, name=name)

  if keepdims:
    raise ValueError('keepdims=True is not supported for RaggedTensors.')

  if isinstance(axis, ops.Tensor):
    axis = tensor_util.constant_value(axis)
    if axis is None:
      raise ValueError('axis must be known at graph construction time.')
    if isinstance(axis, np.ndarray):
      axis = axis.tolist()

  # When reducing all axes, just ignore splits & reduce the inner values.
  if axis is None:
    return reduce_op(rt_input.flat_values, None, name=name)

  with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]):
    # Convert up front so that `rt_input.shape` is available when normalizing
    # the axes below.
    rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        rt_input, name='rt_input')

    if isinstance(axis, (tuple, list)):
      if not axis:
        return rt_input
      elif len(axis) == 1:
        axis = axis[0]
      else:
        # Negative axes must be made positive *before* sorting.  Otherwise,
        # e.g. for a rank-3 input, axis=[-2, 2] sorts as [-2, 2]: axis 2 is
        # reduced first, and -2 is then resolved against the rank-2 result
        # (original axis 0 instead of the requested axis 1).
        ndims = rt_input.shape.ndims
        axis = sorted(ragged_util.get_positive_axis(a, ndims) for a in axis)
        # When reducing multiple axes, just reduce one at a time, innermost
        # first so the remaining (smaller) axis indices stay valid.  This is
        # less efficient, and only works for associative ops.  (In particular,
        # it does not work for reduce_mean.)  However, reducing multiple axes
        # at once will probably require a nontrivial c++ op.
        inner_reduced = _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                                 rt_input, axis[-1], keepdims)
        return _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                        inner_reduced, axis[:-1], keepdims)

    axis = ragged_util.get_positive_axis(axis, rt_input.shape.ndims)

    if axis == 0:
      # out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N]
      row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1]
      num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0)
      segment_ids = range(row_lengths).values
      return _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                       segment_ids, num_segments)
    elif axis == 1:
      # out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N]
      num_segments = array_ops.shape(rt_input.row_splits)[0] - 1
      segment_ids = segment_id_ops.row_splits_to_segment_ids(
          rt_input.row_splits)
      return _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                       segment_ids, num_segments)
    else:
      # out[i_0, ..., i_[axis-1], i_axis+1], ..., i_N] =
      #     sum_{j} rt_input [i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N]
      return rt_input.with_values(
          _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                   rt_input.values, axis - 1, keepdims))
def reduce_sum(input_tensor, axis=None, keepdims=None, name=None):
  """Docstring is generated from _RAGGED_REDUCE_DOCSTRING at import time."""
  return _ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_sum,
      unsorted_segment_op=math_ops.unsorted_segment_sum,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=name or 'RaggedReduceSum')
def reduce_prod(input_tensor, axis=None, keepdims=None, name=None):
  """Docstring is generated from _RAGGED_REDUCE_DOCSTRING at import time."""
  return _ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_prod,
      unsorted_segment_op=math_ops.unsorted_segment_prod,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=name or 'RaggedReduceProd')
def reduce_min(input_tensor, axis=None, keepdims=None, name=None):
  """Docstring is generated from _RAGGED_REDUCE_DOCSTRING at import time."""
  return _ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_min,
      unsorted_segment_op=math_ops.unsorted_segment_min,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=name or 'RaggedReduceMin')
def reduce_max(input_tensor, axis=None, keepdims=None, name=None):
  """Docstring is generated from _RAGGED_REDUCE_DOCSTRING at import time."""
  return _ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_max,
      unsorted_segment_op=math_ops.unsorted_segment_max,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=name or 'RaggedReduceMax')
def reduce_mean(input_tensor, axis=None, keepdims=None, name=None):
  """Docstring is generated from _RAGGED_REDUCE_DOCSTRING at import time."""
  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
    # mean = sum(values) / sum(ones), reduced over the same axes.
    total = reduce_sum(input_tensor, axis, keepdims)
    if not ragged_tensor.is_ragged(input_tensor):
      ones = array_ops.ones_like(input_tensor)
    else:
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(input_tensor.flat_values),
          input_tensor.nested_row_splits)
    count = reduce_sum(ones, axis, keepdims)
    if not ragged_tensor.is_ragged(total):
      return total / count
    averaged_values = total.flat_values / count.flat_values
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        averaged_values, total.nested_row_splits)
def _cast(input_tensor, dtype):
  """Casts `input_tensor` to `dtype` by mapping `math_ops.cast` over its values.

  Args:
    input_tensor: The value to cast; `math_ops.cast` is applied via
      `ragged_functional_ops.map_flat_values`.
    dtype: The target dtype.

  Returns:
    The result of `map_flat_values(math_ops.cast, input_tensor, dtype)`.
  """
  cast_op = math_ops.cast
  return ragged_functional_ops.map_flat_values(cast_op, input_tensor, dtype)
def reduce_all(input_tensor, axis=None, keepdims=None, name=None):
  """Docstring is generated from _RAGGED_REDUCE_DOCSTRING at import time."""
  with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
    # Logical AND is computed as the product of the values viewed as 0/1 ints.
    as_ints = _cast(input_tensor, dtypes.int32)
    product = reduce_prod(as_ints, axis, keepdims)
    return _cast(product, dtypes.bool)
def reduce_any(input_tensor, axis=None, keepdims=None, name=None):
  """Docstring is generated from _RAGGED_REDUCE_DOCSTRING at import time."""
  with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
    # Logical OR is computed as "sum of the values viewed as 0/1 ints is
    # nonzero".
    as_ints = _cast(input_tensor, dtypes.int32)
    total = reduce_sum(as_ints, axis, keepdims)
    return _cast(total, dtypes.bool)
def _set_ragged_reduce_docstring(func, combination, combined, default, example):
  """Assigns `func.__doc__` by filling in `_RAGGED_REDUCE_DOCSTRING`.

  Args:
    func: The reduce function whose `__doc__` is overwritten.
    combination: Substituted for `%(combination)s` (e.g. 'sum').
    combined: Substituted for `%(combined)s` (e.g. 'summed').
    default: Substituted for `%(default)s`: the value produced when a reduced
      dimension has no elements.
    example: Substituted for `%(example)s`: the usage example text.
  """
  template_args = {
      'combination': combination,
      'combined': combined,
      'default': default,
      'example': example,
  }
  func.__doc__ = _RAGGED_REDUCE_DOCSTRING % template_args
# Generate the public docstrings of the ragged reduce ops from the template.
# The `default` argument documents the result for a reduced dimension with no
# elements.
_set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
                             _RAGGED_REDUCE_SUM_EXAMPLE)
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
                             _RAGGED_REDUCE_PROD_EXAMPLE)
# Per the `unsorted_segment_min` / `unsorted_segment_max` documentation, an
# empty segment yields the dtype's largest value for min and its lowest value
# for max, so the documented defaults are `dtype.max` and `dtype.min`
# respectively.
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
                             '`input_tensor.dtype.max`',
                             _RAGGED_REDUCE_MIN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
                             '`input_tensor.dtype.min`',
                             _RAGGED_REDUCE_MAX_EXAMPLE)
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
                             _RAGGED_REDUCE_MEAN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_all, 'logical and', 'and-ed', 'True',
                             _RAGGED_REDUCE_ALL_EXAMPLE)
_set_ragged_reduce_docstring(reduce_any, 'logical or', 'or-ed', 'False',
                             _RAGGED_REDUCE_ANY_EXAMPLE)