# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional operations for RaggedTensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import nest
from tensorflow.python.util.lazy_loader import LazyLoader


# Loaded lazily: the generic map_fn module is only resolved on first use.
map_fn_lib = LazyLoader(
    "map_fn_lib", globals(),
    "tensorflow.python.ops.map_fn")


def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None):
  """map on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `map_fn` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the
  tensors unpacked from `elems`. `dtype` is the data type of the return
  value of `fn`. Users must provide `dtype` if it is different from
  the data type of `elems`.

  Suppose that `elems` is unpacked into `values`, a list of tensors.  The shape
  of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The signature of `fn` may
  match the structure of `elems`.  That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Furthermore, `fn` may emit a different structure than its input.  For example,
  `fn` may look like: `fn = lambda t1: (t1 + 1, t1 - 1)`.  In this case,
  the `dtype` parameter is not optional: `dtype` must be a type or (possibly
  nested) tuple of types matching the output of `fn`.

  To apply a functional operation to the nonzero elements of a SparseTensor
  one of the following methods is recommended. First, if the function is
  expressible as TensorFlow ops, use

  ```python
    result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
  ```

  If, however, the function is not expressible as a TensorFlow op, then use

  ```python
  result = SparseTensor(
    input.indices, map_fn(fn, input.values), input.dense_shape)
  ```

  instead.

  When executing eagerly, map_fn does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.contrib.eager.defun` decorator,

  ```python
  # Assume the function being used in map_fn is fn.
  # To ensure map_fn calls fn in parallel, use the defun decorator.
  @tf.contrib.eager.defun
  def func(tensor):
    return tf.map_fn(fn, tensor)
  ```

  Note that if you use the defun decorator, any non-TensorFlow Python code
  that you may have written in your function won't get executed. See
  `tf.contrib.eager.defun` for more details. The recommendation would be to
  debug without defun but switch to defun to get performance benefits of
  running map_fn in parallel.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will have
      the same (possibly nested) structure as `elems`.  Its output must have the
      same structure as `dtype` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension.  The nested sequence of the
      resulting slices will be applied to `fn`.
    dtype: (optional) The output type(s) of `fn`.  If `fn` returns a structure
      of Tensors differing from the structure of `elems`, then `dtype` is not
      optional and must have the same structure as the output of `fn`. Use
      `RaggedTensorType` to declare an output of type `RaggedTensor`.  A
      `RaggedTensorType` describes the stacked result (including the mapped
      outer dimension), not an individual output of `fn`.  When `dtype` is
      omitted, the dtypes of `elems` are used.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A possibly nested sequence of potentially ragged tensors.  Each
    tensor packs the results of applying `fn` to tensors unpacked from `elems`
    along the first dimension, from first to last.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `dtype` do not match, or if elems is a SparseTensor.
    ValueError: if the lengths of the output of `fn` and `dtype` do not match.

  #### Examples:

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```

    ```python
    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
    # alternate == [-1, 2, -3]
    ```

    ```python
    elems = np.array([1, 2, 3])
    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
    # alternates[0] == [1, 2, 3]
    # alternates[1] == [-1, -2, -3]
    ```

    ```python
    elems = ragged.constant([[1, 2, 3], [4, 5], [6, 7]])
    mean = map_fn(tf.reduce_mean, elems)
    # mean == [2, 4, 6]
    ```

    ```python
    elems = ragged.constant([[1, 2, 3], [4, 5], [6, 7]], dtype=tf.int64)
    out = map_fn(lambda x: x + 1, elems,
                 dtype=ragged.RaggedTensorType(dtype=tf.int64, ragged_rank=1))
    # out = tf.ragged.constant([[2, 3, 4], [5, 6], [7, 8]])
    ```
  """
  # Without an explicit declaration, assume `fn` preserves the dtypes (and
  # nesting structure) of `elems`.
  if dtype is None:
    declared_types = nest.map_structure(lambda elem: elem.dtype, elems)
  else:
    declared_types = dtype

  # The generic map_fn works with specs of individual `fn` outputs, so every
  # RaggedTensorType leaf is translated; all other leaves pass through as-is.
  output_types = nest.map_structure(_ragged_type_to_spec, declared_types)

  return map_fn_lib.map_fn(
      fn, elems, output_types, parallel_iterations, back_prop, swap_memory,
      infer_shape, name)


def _ragged_type_to_spec(t):
  """Converts a `RaggedTensorType` into the matching `RaggedTensorSpec`.

  Args:
    t: A single entry of the `dtype` structure passed to `map_fn`.

  Returns:
    If `t` is a `RaggedTensorType`: a `RaggedTensorSpec` with shape `None`,
    the same `dtype` and `row_splits_dtype`, and `ragged_rank` reduced by one.
    Otherwise `t` itself, unchanged.
  """
  if not isinstance(t, ragged_tensor.RaggedTensorType):
    return t

  # Note: need to adjust ragged_rank by 1, since RaggedTensorSpec gives the
  # type for the mapped `fn` output, but RaggedTensorType gives the type for
  # the result of stacking the mapped `fn` outputs.
  fn_output_ragged_rank = t.ragged_rank - 1
  return ragged_tensor.RaggedTensorSpec(
      None, t.dtype, fn_output_ragged_rank, t.row_splits_dtype)