1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Support for ragged tensors."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops.ragged import ragged_config
26from tensorflow.python.ops.ragged import ragged_tensor
27from tensorflow.python.ops.ragged import ragged_util
28from tensorflow.python.util import dispatch
29from tensorflow.python.util.tf_export import tf_export
30
31
@tf_export("ragged.map_flat_values")
@dispatch.add_dispatch_support
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the `flat_values` of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor (which collapses all ragged dimensions), and then calls `op`.  Returns
  a `RaggedTensor` that is constructed from the input `RaggedTensor`s'
  `nested_row_splits` and the value returned by the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  This operation is generally used to apply elementwise operations to each value
  in a `RaggedTensor`.

  Warning: `tf.ragged.map_flat_values` does *not* apply `op` to each row of a
  ragged tensor.  This difference is important for non-elementwise operations,
  such as `tf.reduce_sum`.  If you wish to apply a non-elementwise operation to
  each row of a ragged tensor, use `tf.map_fn` instead.  (You may need to
  specify an `output_signature` when using `tf.map_fn` with ragged tensors.)

  Examples:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(tf.ones_like, rt)
  <tf.RaggedTensor [[1, 1, 1], [], [1, 1], [1]]>
  >>> tf.ragged.map_flat_values(tf.multiply, rt, rt)
  <tf.RaggedTensor [[1, 4, 9], [], [16, 25], [36]]>
  >>> tf.ragged.map_flat_values(tf.add, rt, 5)
  <tf.RaggedTensor [[6, 7, 8], [], [9, 10], [11]]>

  Example with a non-elementwise operation (note that `map_flat_values` and
  `map_fn` return different results):

  >>> rt = tf.ragged.constant([[1.0, 3.0], [], [3.0, 6.0, 3.0]])
  >>> def normalized(x):
  ...   return x / tf.reduce_sum(x)
  >>> tf.ragged.map_flat_values(normalized, rt)
  <tf.RaggedTensor [[0.0625, 0.1875], [], [0.1875, 0.375, 0.1875]]>
  >>> tf.map_fn(normalized, rt)
  <tf.RaggedTensor [[0.25, 0.75], [], [0.25, 0.5, 0.25]]>

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.  (This is
      checked statically only when both sizes are known.)
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s.  If neither `args` nor `kwargs` contains a
    `RaggedTensor`, `op(*args, **kwargs)` is called on the unmodified
    arguments and its result is returned as-is.

  Raises:
    ValueError: If the input `RaggedTensor`s' `flat_values` have statically
      known outer-dimension sizes that differ; if the input `RaggedTensor`s
      have different `row_splits` dtypes and automatic partition-dtype casting
      is disabled; or if the statically known outer-dimension size of `op`'s
      output does not match that of the inputs' `flat_values`.  Checking that
      all inputs share identical `nested_row_splits` is delegated to
      `ragged_util.assert_splits_match`.
  """
  # Replace RaggedTensors with their flat_values, collecting each one's
  # nested_row_splits and (when statically known) its flat_values outer size.
  nested_splits_lists = []
  static_nrows = []
  inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists,
                                                static_nrows)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists,
                                                  static_nrows)

  # No ragged inputs: there is nothing to re-wrap, so just call `op`.
  if not nested_splits_lists:
    return op(*args, **kwargs)

  # Every statically-known flat_values outer size must agree; the single
  # agreed-upon size (or None if no size was known) is used to check `op`'s
  # output below.
  distinct_nrows = set(static_nrows)
  if len(distinct_nrows) > 1:
    raise ValueError("Input RaggedTensors' flat_values must all have the "
                     "same outer-dimension size.  Got sizes: %s" %
                     distinct_nrows)
  flat_values_nrows = distinct_nrows.pop() if distinct_nrows else None

  # Inputs whose row_splits dtypes differ are either rejected or (when
  # auto-casting is enabled) all cast to int64 so they can be compared.
  split_dtypes = set(splits[0].dtype for splits in nested_splits_lists)
  if len(split_dtypes) > 1:
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; "
                       "use RaggedTensor.with_row_splits_dtype() to convert "
                       "them to compatible dtypes.")

    nested_splits_lists = [
        [math_ops.cast(s, dtypes.int64) for s in nested_splits]  # pylint: disable=g-complex-comprehension
        for nested_splits in nested_splits_lists]

  # Ops created inside this scope get control dependencies on the
  # splits-equality assertions returned by assert_splits_match.
  with ops.control_dependencies(
      ragged_util.assert_splits_match(nested_splits_lists)):
    # Delegate to `op`
    op_output = op(*inner_args, **inner_kwargs)
    # Check that the result has the expected shape (if known).
    if flat_values_nrows is not None:
      if not op_output.shape[:1].is_compatible_with([flat_values_nrows]):
        raise ValueError(
            "tf.ragged.map_flat_values requires that the output of `op` have "
            "the same outer-dimension size as flat_values of any ragged "
            "inputs. (output shape: %s; expected outer dimension size: %s)" %
            (op_output.shape, flat_values_nrows))
    # Compose the result from the transformed values and the splits.  The
    # splits were taken from the input RaggedTensors, so they are not
    # re-validated here.
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        op_output, nested_splits_lists[0], validate=False)
137
138
def _replace_ragged_with_flat_values(value, nested_splits_lists,
                                     flat_values_nrows):
  """Replace RaggedTensors with their flat_values, and record their splits.

  Returns a copy of `value`, with any nested `RaggedTensor`s replaced by their
  `flat_values` tensor.  Looks inside lists, tuples (including namedtuples,
  whose type is preserved), and dicts; any other non-ragged value is returned
  unchanged (not copied).

  Appends each `RaggedTensor`'s `nested_splits` to `nested_splits_lists`, in
  traversal order.

  Args:
    value: The value that should be transformed by replacing `RaggedTensors`.
    nested_splits_lists: An output parameter used to record the `nested_splits`
      for any `RaggedTensors` that were replaced.
    flat_values_nrows: An output parameter used to record the outer dimension
      size for each replacement `flat_values` (when known).  Contains a list of
      int.

  Returns:
    A copy of `value` with nested `RaggedTensors` replaced by their
    `flat_values`.
  """
  # Base case: a ragged value is swapped for its flat_values, and its splits
  # (plus its flat_values outer size, when statically known) are recorded.
  if ragged_tensor.is_ragged(value):
    value = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
    nested_splits_lists.append(value.nested_row_splits)
    nrows = tensor_shape.dimension_at_index(value.flat_values.shape, 0).value
    if nrows is not None:
      flat_values_nrows.append(nrows)
    return value.flat_values

  # Recursion cases
  def recurse(v):
    return _replace_ragged_with_flat_values(v, nested_splits_lists,
                                            flat_values_nrows)

  if isinstance(value, list):
    return [recurse(v) for v in value]
  elif isinstance(value, tuple):
    replaced = [recurse(v) for v in value]
    # Rebuild namedtuples with their own type so `op` can still use field
    # access; every other tuple is rebuilt as a plain tuple.
    if hasattr(value, "_fields") and hasattr(value, "_make"):
      return type(value)._make(replaced)
    return tuple(replaced)
  elif isinstance(value, dict):
    return {k: recurse(v) for k, v in value.items()}
  else:
    return value
181