1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Private convenience functions for RaggedTensors.
16
17None of these methods are exposed in the main "ragged" package.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import check_ops
26from tensorflow.python.ops import gen_ragged_math_ops
27from tensorflow.python.ops import math_ops
28
29
30
def assert_splits_match(nested_splits_lists):
  """Verifies that several nested splits lists agree with one another.

  A static check first confirms that every splits list has the same number of
  entries as the first one.  Then, for every splits list after the first, an
  `assert_equal` op is built for each pair of corresponding `splits` tensors
  (first list vs. that list).

  Args:
    nested_splits_lists: A list of splits lists.  Each splits list is itself a
      list of `splits` tensors from a `RaggedTensor`, ordered from the
      outermost ragged dimension to the innermost ragged dimension.

  Returns:
    A list of control dependency op tensors, one per compared pair of `splits`
    tensors.  The list is empty when there is nothing to compare.

  Raises:
    ValueError: If the splits lists do not all have the same length.
  """
  error_msg = "Inputs must have identical ragged splits"

  # Static check: run all length comparisons before creating any ops, so that
  # a mismatch raises without adding assertions to the graph.
  depth_mismatch = any(
      len(candidate) != len(nested_splits_lists[0])
      for candidate in nested_splits_lists)
  if depth_mismatch:
    raise ValueError(error_msg)

  # Dynamic checks: compare every later splits list against the first one,
  # dimension by dimension.
  checks = []
  for other_splits in nested_splits_lists[1:]:
    for expected, actual in zip(nested_splits_lists[0], other_splits):
      checks.append(
          check_ops.assert_equal(expected, actual, message=error_msg))
  return checks
57
58
# Note: imported here to avoid circular dependency of array_ops.
# The names below are module-level aliases for helpers that live in
# `array_ops`, so that they can be referenced as attributes of this module.
# Be aware that `repeat` is bound to `array_ops.repeat_with_axis`: the alias
# and the underlying function have different names.  `repeat` is the helper
# that `repeat_ranges` (below) calls.
get_positive_axis = array_ops.get_positive_axis
convert_to_int_tensor = array_ops.convert_to_int_tensor
repeat = array_ops.repeat_with_axis
63
64
def lengths_to_splits(lengths):
  """Converts row lengths into the corresponding splits.

  The result is a literal `[0]` concatenated (along the last axis) with the
  cumulative sum of `lengths`.

  Args:
    lengths: The lengths to convert; passed directly to `math_ops.cumsum`.

  Returns:
    The concatenation of `[0]` and the running total of `lengths`.
  """
  leading_zero = [0]
  running_totals = math_ops.cumsum(lengths)
  return array_ops.concat([leading_zero, running_totals], axis=-1)
68
69
def repeat_ranges(params, splits, repeats):
  """Tiles each `splits`-delimited range of `params` `repeats` times.

  Range `i` of `params` is the slice `params[splits[i]:splits[i + 1]]`.  The
  result is range 0 repeated `repeats[0]` times, then range 1 repeated
  `repeats[1]` times, and so on through the final range, which is repeated
  `repeats[-1]` times.

  Args:
    params: The `Tensor` whose values should be repeated.
    splits: A splits tensor delimiting the ranges of `params` that should be
      repeated.
    repeats: The number of times each range should be repeated.  Supports
      broadcasting from a scalar value.  Its static rank is read via
      `repeats.shape.ndims`, so it must expose a `shape`.

  Returns:
    A `Tensor` with the same rank and type as `params`.

  #### Example:

  >>> print(repeat_ranges(
  ...     params=tf.constant(['a', 'b', 'c']),
  ...     splits=tf.constant([0, 2, 3]),
  ...     repeats=tf.constant(3)))
  tf.Tensor([b'a' b'b' b'a' b'b' b'a' b'b' b'c' b'c' b'c'],
      shape=(9,), dtype=string)
  """
  # Build, for every repeated range, its start offset and its limit offset.
  if repeats.shape.ndims == 0:
    # Scalar fast path: repeat all of `splits` in a single call.  Dropping the
    # trailing `repeats` entries leaves the repeated starts; dropping the
    # leading `repeats` entries leaves the repeated limits.
    tiled_splits = repeat(splits, repeats, axis=0)
    num_tiled = array_ops.shape(tiled_splits, out_type=repeats.dtype)[0]
    range_starts = tiled_splits[:num_tiled - repeats]
    range_limits = tiled_splits[repeats:]
  else:
    # General path (also taken when the static rank is unknown): repeat the
    # starts and the limits separately.
    range_starts = repeat(splits[:-1], repeats, axis=0)
    range_limits = repeat(splits[1:], repeats, axis=0)

  # Expand every [start, limit) pair into consecutive indices (step of one),
  # then gather `params` at those indices to produce the repetition pattern.
  step = array_ops.ones((), range_starts.dtype)
  gather_indices = gen_ragged_math_ops.ragged_range(
      range_starts, range_limits, step).rt_dense_values
  return array_ops.gather(params, gather_indices)
114