# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Support for ragged tensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util.tf_export import tf_export


@tf_export("ragged.map_flat_values")
def map_flat_values(op, *args, **kwargs):
  """Applies `op` to the values of one or more RaggedTensors.

  Replaces any `RaggedTensor` in `args` or `kwargs` with its `flat_values`
  tensor, and then calls `op`.  Returns a `RaggedTensor` that is constructed
  from the input `RaggedTensor`s' `nested_row_splits` and the value returned by
  the `op`.

  If the input arguments contain multiple `RaggedTensor`s, then they must have
  identical `nested_row_splits`.

  If the input arguments contain no `RaggedTensor`s, then `op` is simply called
  with the original `args` and `kwargs`, and its result is returned unchanged.

  Examples:

  ```python
  >>> rt = ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> ragged.map_flat_values(tf.ones_like, rt).eval().tolist()
  [[1, 1, 1], [], [1, 1], [1]]
  >>> ragged.map_flat_values(tf.multiply, rt, rt).eval().tolist()
  [[1, 4, 9], [], [16, 25], [36]]
  >>> ragged.map_flat_values(tf.add, rt, 5).eval().tolist()
  [[6, 7, 8], [], [9, 10], [11]]
  ```

  Args:
    op: The operation that should be applied to the RaggedTensor `flat_values`.
      `op` is typically an element-wise operation (such as math_ops.add), but
      any operation that preserves the size of the outermost dimension can be
      used.  I.e., `shape[0]` of the value returned by `op` must match
      `shape[0]` of the `RaggedTensor`s' `flat_values` tensors.
    *args: Arguments for `op`.
    **kwargs: Keyword arguments for `op`.

  Returns:
    A `RaggedTensor` whose `ragged_rank` matches the `ragged_rank` of all
    input `RaggedTensor`s.  If neither `args` nor `kwargs` contains a
    `RaggedTensor`, the value of `op(*args, **kwargs)` is returned as-is.

  Raises:
    ValueError: If the `nested_row_splits` of the input `RaggedTensor`s are
      not identical.  The check is delegated to
      `ragged_util.assert_splits_match`.
  """
  # Replace RaggedTensors with their flat_values, and collect the
  # nested_row_splits of every RaggedTensor that was replaced.
  nested_splits_lists = []
  inner_args = _replace_ragged_with_flat_values(args, nested_splits_lists)
  inner_kwargs = _replace_ragged_with_flat_values(kwargs, nested_splits_lists)

  # No ragged inputs: there are no splits to re-apply, so call `op` directly
  # with the caller's original arguments.
  if not nested_splits_lists:
    return op(*args, **kwargs)

  # Run `op` under the splits-consistency assertions, then rebuild a
  # RaggedTensor from the transformed flat values and the splits of the first
  # ragged input (all inputs are required to share the same splits).
  with ops.control_dependencies(
      ragged_util.assert_splits_match(nested_splits_lists)):
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        op(*inner_args, **inner_kwargs), nested_splits_lists[0])


def _replace_ragged_with_flat_values(value, nested_splits_lists):
  """Replace RaggedTensors with their flat_values, and record their splits.

  Returns a copy of `value`, with any nested `RaggedTensor`s replaced by their
  `flat_values` tensor.  Looks inside lists, tuples, and dicts.

  Appends each `RaggedTensor`'s `nested_row_splits` to `nested_splits_lists`.

  Args:
    value: The value that should be transformed by replacing `RaggedTensors`.
    nested_splits_lists: An output parameter used to record the
      `nested_row_splits` for any `RaggedTensors` that were replaced.  It is
      mutated in place.

  Returns:
    A copy of `value` with nested `RaggedTensors` replaced by their
    `flat_values`.  Values that are not ragged and are not a list, tuple, or
    dict are returned unchanged (not copied).
  """
  # Base case: a ragged value is swapped for its flat_values.
  if ragged_tensor.is_ragged(value):
    # Normalize first so that `nested_row_splits` / `flat_values` can be read
    # from whatever ragged representation `is_ragged` accepted.
    value = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
    nested_splits_lists.append(value.nested_row_splits)
    return value.flat_values

  # Recursion cases: rebuild the container with each element transformed.
  def recurse(v):
    return _replace_ragged_with_flat_values(v, nested_splits_lists)

  if isinstance(value, list):
    return [recurse(v) for v in value]
  elif isinstance(value, tuple):
    return tuple(recurse(v) for v in value)
  elif isinstance(value, dict):
    return {k: recurse(v) for (k, v) in value.items()}
  else:
    return value