1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Private convenience functions for RaggedTensors.
16
17None of these methods are exposed in the main "ragged" package.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import check_ops
28from tensorflow.python.ops import gen_ragged_math_ops
29from tensorflow.python.ops import math_ops
30
31
def convert_to_int_tensor(tensor, name, dtype=dtypes.int32):
  """Converts `tensor` to a Tensor with integer dtype `dtype`.

  The value is first converted with `ops.convert_to_tensor` (preferring
  `dtype` for Python values), then cast to `dtype`.

  Args:
    tensor: A value convertible to a Tensor.
    name: Name used for the conversion and in the error message.
    dtype: The integer dtype of the result.

  Returns:
    A Tensor of dtype `dtype`.

  Raises:
    TypeError: If the converted value does not have an integer dtype.
  """
  converted = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype)
  if not converted.dtype.is_integer:
    raise TypeError(
        "%s must be an integer tensor; dtype=%s" % (name, converted.dtype))
  return math_ops.cast(converted, dtype)
41
42
def get_positive_axis(axis, ndims):
  """Validates `axis` and maps it into the non-negative range.

  When `ndims` is known (not `None`), `axis` must satisfy
  `-ndims <= axis < ndims`; negative values are shifted by `ndims`.
  When `ndims` is `None`, non-negative values are returned unchanged and
  negative values are rejected, since they cannot be resolved.

  Args:
    axis: An integer constant.
    ndims: An integer constant, or `None`.

  Returns:
    The normalized, non-negative `axis` value.

  Raises:
    TypeError: If `axis` is not an `int`.
    ValueError: If `axis` is out-of-bounds, or if `axis` is negative and
      `ndims is None`.
  """
  if not isinstance(axis, int):
    raise TypeError("axis must be an int; got %s" % type(axis).__name__)

  if ndims is None:
    if axis < 0:
      raise ValueError(
          "axis may only be negative if ndims is statically known.")
    return axis

  # [0, ndims) and [-ndims, 0) together form [-ndims, ndims).
  if not (-ndims <= axis < ndims):
    raise ValueError(
        "axis=%s out of bounds: expected %s<=axis<%s" % (axis, -ndims, ndims))
  return axis + ndims if axis < 0 else axis
76
77
def assert_splits_match(nested_splits_lists):
  """Checks that the given splits lists are identical.

  Statically verifies that every splits list contains the same number of
  `splits` tensors as the first one, then builds one `assert_equal` op per
  pair of corresponding tensors (each later list compared against the first).

  Args:
    nested_splits_lists: A list of nested_splits_lists, where each split_list is
      a list of `splits` tensors from a `RaggedTensor`, ordered from outermost
      ragged dimension to innermost ragged dimension.

  Returns:
    A list of control dependency op tensors.  Empty if fewer than two splits
    lists are given.

  Raises:
    ValueError: If the splits lists do not all have the same length.
  """
  error_msg = "Inputs must have identical ragged splits"
  # The generator only indexes [0] when there is at least one element.
  if any(len(splits_list) != len(nested_splits_lists[0])
         for splits_list in nested_splits_lists):
    raise ValueError(error_msg)

  checks = []
  for other_splits in nested_splits_lists[1:]:
    for lhs, rhs in zip(nested_splits_lists[0], other_splits):
      checks.append(check_ops.assert_equal(lhs, rhs, message=error_msg))
  return checks
104
105
106# This op is intended to exactly match the semantics of numpy.repeat, with
107# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior
108# when axis is not specified.  Rather than implement that special behavior, we
109# simply make `axis` be a required argument.
110#
111# External (OSS) `tf.repeat` feature request:
112# https://github.com/tensorflow/tensorflow/issues/8246
def repeat(data, repeats, axis, name=None):
  """Repeats elements of `data` along `axis`.

  Args:
    data: An `N`-dimensional tensor.
    repeats: A 1-D integer tensor specifying how many times each element in
      `axis` should be repeated.  `len(repeats)` must equal `data.shape[axis]`.
      Supports broadcasting from a scalar value.
    axis: `int`.  The axis along which to repeat values.  Must satisfy
      `-max(N, 1) <= axis < max(N, 1)`.  May only be negative if the rank of
      `data` is statically known.
    name: A name for the operation.

  Returns:
    A tensor with `max(N, 1)` dimensions.  Has the same shape as `data`,
    except that dimension `axis` has size `sum(repeats)`.

  Raises:
    TypeError: If `axis` is not an `int`, or `repeats` is not an integer
      tensor.
    ValueError: If `repeats` is statically known to have rank > 1, if `axis`
      is out of bounds, or if `len(repeats)` is statically incompatible with
      `data.shape[axis]`.

  #### Examples:
    ```python
    >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
    ['a', 'a', 'a', 'c', 'c']
    >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
    [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
    >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
    [[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
    ```
  """
  if not isinstance(axis, int):
    raise TypeError("axis must be an int; got %s" % type(axis).__name__)

  with ops.name_scope(name, "Repeat", [data, repeats]):
    data = ops.convert_to_tensor(data, name="data")
    repeats = convert_to_int_tensor(repeats, name="repeats")
    # Static check only: raises if `repeats` is known to have rank > 1.
    repeats.shape.with_rank_at_most(1)

    # If `data` is a scalar, then upgrade it to a vector.
    data = _with_nonzero_rank(data)
    data_shape = array_ops.shape(data)

    # If `axis` is negative, then convert it to a positive value.
    axis = get_positive_axis(axis, data.shape.ndims)

    # Check data Tensor shapes.  `data.shape.dims` is None when the rank of
    # `data` is statically unknown, so this static check is only possible
    # (and only attempted) when the rank is known.
    if repeats.shape.ndims == 1 and data.shape.ndims is not None:
      data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])

    # If we know that `repeats` is a scalar, then we can just tile & reshape.
    if repeats.shape.ndims == 0:
      expanded = array_ops.expand_dims(data, axis + 1)
      tiled = tile_one_dimension(expanded, axis + 1, repeats)
      result_shape = array_ops.concat(
          [data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0)
      return array_ops.reshape(tiled, result_shape)

    # Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
    if repeats.shape.ndims != axis + 1:
      repeats_shape = array_ops.shape(repeats)
      repeats_ndims = array_ops.rank(repeats)
      broadcast_shape = array_ops.concat(
          [data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0)
      repeats = array_ops.broadcast_to(repeats, broadcast_shape)
      repeats.set_shape([None] * (axis + 1))

    # Create a "sequence mask" based on `repeats`, where slices across `axis`
    # contain one `True` value for each repetition.  E.g., if
    # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`.
    # `maximum(0, ...)` keeps `max_repeat` non-negative when `repeats` is empty.
    max_repeat = math_ops.maximum(0, math_ops.reduce_max(repeats))
    mask = array_ops.sequence_mask(repeats, max_repeat)

    # Add a new dimension around each value that needs to be repeated, and
    # then tile that new dimension to match the maximum number of repetitions.
    expanded = array_ops.expand_dims(data, axis + 1)
    tiled = tile_one_dimension(expanded, axis + 1, max_repeat)

    # Use `boolean_mask` to discard the extra repeated values.  This also
    # flattens all dimensions up through `axis`.
    masked = array_ops.boolean_mask(tiled, mask)

    # Reshape the output tensor to add the outer dimensions back.
    if axis == 0:
      result = masked
    else:
      result_shape = array_ops.concat(
          [data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0)
      result = array_ops.reshape(masked, result_shape)

    # Preserve static shape information for every dimension except `axis`.
    # The size of `axis` is only set statically (to 0) when the leading
    # dimension of `repeats` is statically 0.
    if data.shape.ndims is not None:
      new_axis_size = 0 if repeats.shape[0] == 0 else None
      result.set_shape(data.shape[:axis].concatenate(
          [new_axis_size]).concatenate(data.shape[axis + 1:]))

    return result
205
206
def tile_one_dimension(data, axis, multiple):
  """Tiles only dimension `axis` of `data`, `multiple` times.

  All other dimensions get a tile multiple of 1.  `axis` is assumed to be a
  non-negative int.

  Args:
    data: The tensor to tile.
    axis: Non-negative index of the dimension to tile.
    multiple: Number of copies along `axis` (int or scalar int tensor).

  Returns:
    The result of `array_ops.tile(data, multiples)`.
  """
  static_rank = data.shape.ndims
  if static_rank is None:
    # Rank is unknown statically, so build the multiples vector as a tensor.
    unit_multiples = array_ops.ones(array_ops.rank(data), dtypes.int32)
    multiples = array_ops.concat(
        [unit_multiples[:axis], [multiple], unit_multiples[axis + 1:]],
        axis=0)
  else:
    multiples = [1] * static_rank
    multiples[axis] = multiple
  return array_ops.tile(data, multiples)
218
219
def _with_nonzero_rank(data):
  """Returns `data` reshaped to rank 1 if it is a scalar, else `data` as-is."""
  static_rank = data.shape.ndims
  if static_rank is None:
    # Rank unknown statically: prepend a 1 to the dynamic shape and keep the
    # trailing `rank(data)` entries.  For a scalar, `[-0:]` keeps the whole
    # `[1]`; for rank r > 0 it drops the prepended 1 (a no-op reshape).
    dynamic_shape = array_ops.shape(data)
    dynamic_rank = array_ops.rank(data)
    padded_shape = array_ops.concat([[1], dynamic_shape], axis=0)
    return array_ops.reshape(data, padded_shape[-dynamic_rank:])
  if static_rank == 0:
    return array_ops.stack([data])
  return data
233
234
def lengths_to_splits(lengths):
  """Converts row lengths to splits by prepending 0 to their cumulative sum.

  E.g. lengths `[2, 0, 3]` become splits `[0, 2, 2, 5]`.
  """
  running_totals = math_ops.cumsum(lengths)
  return array_ops.concat([[0], running_totals], axis=-1)
238
239
def repeat_ranges(params, splits, repeats):
  """Repeats each range of `params` (as specified by `splits`) `repeats` times.

  Let the `i`th range of `params` be defined as
  `params[splits[i]:splits[i + 1]]`.  Then this function returns a tensor
  containing range 0 repeated `repeats[0]` times, followed by range 1 repeated
  `repeats[1]` times, ..., followed by the last range repeated `repeats[-1]`
  times.

  Args:
    params: The `Tensor` whose values should be repeated.
    splits: A splits tensor indicating the ranges of `params` that should be
      repeated.
    repeats: The number of times each range should be repeated.  Supports
      broadcasting from a scalar value.  May be a `Tensor` or any value
      convertible to one (e.g. a Python `int`).  When scalar, its dtype must
      be a valid `out_type` for `array_ops.shape` (int32 or int64).

  Returns:
    A `Tensor` with the same rank and type as `params`.

  #### Example:
    ```python
    >>> repeat_ranges(['a', 'b', 'c'], [0, 2, 3], 3)
    ['a', 'b', 'a', 'b', 'a', 'b', 'c', 'c', 'c']
    ```
  """
  # Accept plain Python values (as in the example above); the code below
  # reads `repeats.shape` and `repeats.dtype`.  A no-op for Tensors.
  repeats = ops.convert_to_tensor(repeats, name="repeats")

  # Divide `splits` into starts and limits, and repeat them `repeats` times.
  if repeats.shape.ndims != 0:
    repeated_starts = repeat(splits[:-1], repeats, axis=0)
    repeated_limits = repeat(splits[1:], repeats, axis=0)
  else:
    # Optimization: we can just call repeat once, and then slice the result.
    # Each split value appears `repeats` times in `repeated_splits`, so
    # dropping the last `repeats` entries yields the starts and dropping the
    # first `repeats` entries yields the limits.
    repeated_splits = repeat(splits, repeats, axis=0)
    # Use `repeats.dtype` so `n_splits - repeats` never mixes int32 and int64.
    # TensorFlow does not implicitly promote between integer dtypes.
    n_splits = array_ops.shape(repeated_splits, out_type=repeats.dtype)[0]
    repeated_starts = repeated_splits[:n_splits - repeats]
    repeated_limits = repeated_splits[repeats:]

  # Get indices for each range from starts to limits, and use those to gather
  # the values in the desired repetition pattern.
  one = array_ops.ones((), repeated_starts.dtype)
  offsets = gen_ragged_math_ops.ragged_range(
      repeated_starts, repeated_limits, one)
  return array_ops.gather(params, offsets.rt_dense_values)
281