# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional ops for ragged tensors: `tf.ragged.map_flat_values`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export("ragged.map_flat_values")
@dispatch.add_dispatch_support
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the `flat_values` of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor (which collapses all ragged dimensions), and then calls `op`.  Returns
  a `RaggedTensor` that is constructed from the input `RaggedTensor`s'
  `nested_row_splits` and the value returned by the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  This operation is generally used to apply elementwise operations to each value
  in a `RaggedTensor`.

  Warning: `tf.ragged.map_flat_values` does *not* apply `op` to each row of a
  ragged tensor.  This difference is important for non-elementwise operations,
  such as `tf.reduce_sum`.  If you wish to apply a non-elementwise operation to
  each row of a ragged tensor, use `tf.map_fn` instead.  (You may need to
  specify an `output_signature` when using `tf.map_fn` with ragged tensors.)

  Examples:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(tf.ones_like, rt)
  <tf.RaggedTensor [[1, 1, 1], [], [1, 1], [1]]>
  >>> tf.ragged.map_flat_values(tf.multiply, rt, rt)
  <tf.RaggedTensor [[1, 4, 9], [], [16, 25], [36]]>
  >>> tf.ragged.map_flat_values(tf.add, rt, 5)
  <tf.RaggedTensor [[6, 7, 8], [], [9, 10], [11]]>

  Example with a non-elementwise operation (note that `map_flat_values` and
  `map_fn` return different results):

  >>> rt = tf.ragged.constant([[1.0, 3.0], [], [3.0, 6.0, 3.0]])
  >>> def normalized(x):
  ...   return x / tf.reduce_sum(x)
  >>> tf.ragged.map_flat_values(normalized, rt)
  <tf.RaggedTensor [[0.0625, 0.1875], [], [0.1875, 0.375, 0.1875]]>
  >>> tf.map_fn(normalized, rt)
  <tf.RaggedTensor [[0.25, 0.75], [], [0.25, 0.5, 0.25]]>

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.  Any `RaggedTensor` (including one nested inside
      a list, tuple, or dict) is replaced by its `flat_values` before `op` is
      called.
    **kwargs: Keyword arguments for `op`, transformed the same way as `args`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s.  If `args` and `kwargs` contain no `RaggedTensor`s,
    `op` is called with the original `args` and `kwargs` and its result is
    returned as-is.

  Raises:
    ValueError: If the input `RaggedTensor`s' `flat_values` have statically
      known outer-dimension sizes that differ; if the inputs' `row_splits`
      dtypes differ and `ragged_config.auto_cast_partition_dtype()` is false;
      if the value returned by `op` has an outer dimension that is
      incompatible with the inputs' `flat_values`; or (as reported by
      `ragged_util.assert_splits_match`) if the `nested_row_splits` of the
      input `RaggedTensor`s are not identical.
  """
  # Replace RaggedTensors with their flat_values, collecting each one's
  # nested_row_splits and (when statically known) its flat_values outer size.
  nested_splits_lists = []
  flat_values_nrows = []
  inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists,
                                                flat_values_nrows)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists,
                                                  flat_values_nrows)
  # No ragged inputs: nothing to map over, so call `op` on the originals.
  if not nested_splits_lists:
    return op(*args, **kwargs)

  # All statically-known outer sizes must agree.  Remember the common size
  # (None when no size is statically known) to validate op's output below.
  expected_nrows = None
  if flat_values_nrows:
    distinct_nrows = set(flat_values_nrows)
    if len(distinct_nrows) != 1:
      raise ValueError("Input RaggedTensors' flat_values must all have the "
                       "same outer-dimension size. Got sizes: %s" %
                       distinct_nrows)
    expected_nrows = distinct_nrows.pop()  # Get the single element

  # Compare the row_splits dtype of each input (via its outermost splits).
  # On a mismatch, either fail or cast every splits tensor to int64.
  split_dtypes = {splits[0].dtype for splits in nested_splits_lists}
  if len(split_dtypes) > 1:
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; "
                       "use RaggedTensor.with_row_splits_dtype() to convert "
                       "them to compatible dtypes.")

    nested_splits_lists = [
        [math_ops.cast(s, dtypes.int64) for s in nested_splits]  # pylint: disable=g-complex-comprehension
        for nested_splits in nested_splits_lists]

  # `op` is created under control dependencies on the splits-equality checks.
  with ops.control_dependencies(
      ragged_util.assert_splits_match(nested_splits_lists)):
    # Delegate to `op` on the flat_values.
    op_output = op(*inner_args, **inner_kwargs)
    # Check that the result has the expected outer dimension (if known).
    if (expected_nrows is not None and
        not op_output.shape[:1].is_compatible_with([expected_nrows])):
      raise ValueError(
          "tf.ragged.map_flat_values requires that the output of `op` have "
          "the same outer-dimension size as flat_values of any ragged "
          "inputs. (output shape: %s; expected outer dimension size: %s)" %
          (op_output.shape, expected_nrows))
    # Compose the result from the transformed values and the first input's
    # splits (all inputs' splits are asserted to match above).
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        op_output, nested_splits_lists[0], validate=False)


def _replace_ragged_with_flat_values(value, nested_splits_lists,
                                     flat_values_nrows):
  """Replace RaggedTensors with their flat_values, and record their splits.

  Returns a copy of `value`, with any nested `RaggedTensor`s replaced by their
  `flat_values` tensor.  Looks (recursively) inside lists, tuples, and dicts;
  any other value is returned unchanged.  Containers are rebuilt with the
  builtin type, so tuple subclasses (e.g. namedtuples) come back as plain
  `tuple`s and dict subclasses come back as plain `dict`s.

  Appends each `RaggedTensor`'s `nested_row_splits` to `nested_splits_lists`.

  Args:
    value: The value that should be transformed by replacing `RaggedTensors`.
    nested_splits_lists: An output parameter used to record the
      `nested_row_splits` for any `RaggedTensors` that were replaced.
    flat_values_nrows: An output parameter used to record the outer dimension
      size for each replacement `flat_values` (only when statically known).
      Contains a list of int.

  Returns:
    A copy of `value` with nested `RaggedTensors` replaced by their
    `flat_values`.
  """
  # Base case: a ragged value is normalized to a RaggedTensor, then replaced.
  if ragged_tensor.is_ragged(value):
    value = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
    nested_splits_lists.append(value.nested_row_splits)
    # Record the outer size only when the static shape provides it.
    nrows = tensor_shape.dimension_at_index(value.flat_values.shape, 0).value
    if nrows is not None:
      flat_values_nrows.append(nrows)
    return value.flat_values

  # Recursion cases: rebuild containers with each element transformed.
  def recurse(v):
    return _replace_ragged_with_flat_values(v, nested_splits_lists,
                                            flat_values_nrows)

  if isinstance(value, list):
    return [recurse(v) for v in value]
  elif isinstance(value, tuple):
    return tuple(recurse(v) for v in value)
  elif isinstance(value, dict):
    return {k: recurse(v) for (k, v) in value.items()}
  else:
    return value